From 83bce6733b03a4209fe4a62e71fed49a1d4526c5 Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Sat, 15 Feb 2025 09:20:48 -0500 Subject: [PATCH 01/32] Work in progress adding user defined aggregate function FFI support --- datafusion/ffi/src/arrow_wrappers.rs | 13 +- datafusion/ffi/src/lib.rs | 1 + datafusion/ffi/src/udaf/accumulator.rs | 342 +++++++++++++++++++++++++ datafusion/ffi/src/udaf/mod.rs | 302 ++++++++++++++++++++++ 4 files changed, 657 insertions(+), 1 deletion(-) create mode 100644 datafusion/ffi/src/udaf/accumulator.rs create mode 100644 datafusion/ffi/src/udaf/mod.rs diff --git a/datafusion/ffi/src/arrow_wrappers.rs b/datafusion/ffi/src/arrow_wrappers.rs index eb1f34b3d93a..547dd0156d9b 100644 --- a/datafusion/ffi/src/arrow_wrappers.rs +++ b/datafusion/ffi/src/arrow_wrappers.rs @@ -21,7 +21,7 @@ use abi_stable::StableAbi; use arrow::{ array::{make_array, ArrayRef}, datatypes::{Schema, SchemaRef}, - ffi::{from_ffi, FFI_ArrowArray, FFI_ArrowSchema}, + ffi::{from_ffi, to_ffi, FFI_ArrowArray, FFI_ArrowSchema}, }; use log::error; @@ -79,3 +79,14 @@ impl TryFrom for ArrayRef { Ok(make_array(data)) } } + +impl TryFrom<&ArrayRef> for WrappedArray { + type Error = arrow::error::ArrowError; + + fn try_from(array: &ArrayRef) -> Result { + let (array, schema) = to_ffi(&array.to_data())?; + let schema = WrappedSchema(schema); + + Ok(WrappedArray { array, schema }) + } +} diff --git a/datafusion/ffi/src/lib.rs b/datafusion/ffi/src/lib.rs index d877e182a1d8..755c460f3133 100644 --- a/datafusion/ffi/src/lib.rs +++ b/datafusion/ffi/src/lib.rs @@ -34,6 +34,7 @@ pub mod schema_provider; pub mod session_config; pub mod table_provider; pub mod table_source; +pub mod udaf; pub mod udf; pub mod udtf; pub mod util; diff --git a/datafusion/ffi/src/udaf/accumulator.rs b/datafusion/ffi/src/udaf/accumulator.rs new file mode 100644 index 000000000000..b46c9a479f35 --- /dev/null +++ b/datafusion/ffi/src/udaf/accumulator.rs @@ -0,0 +1,342 @@ +use std::{ + ffi::c_void, + sync::{Arc, 
Mutex}, +}; + +use abi_stable::{ + std_types::{RResult, RString, RVec}, + StableAbi, +}; +use arrow::{array::ArrayRef, error::ArrowError}; +use datafusion::{ + error::{DataFusionError, Result}, + logical_expr::Accumulator, + scalar::ScalarValue, +}; +use prost::Message; + +use crate::{ + arrow_wrappers::WrappedArray, + df_result, rresult, rresult_return, +}; + +#[repr(C)] +#[derive(Debug, StableAbi)] +#[allow(non_camel_case_types)] +pub struct FFI_Accumulator { + pub update_batch: unsafe extern "C" fn( + accumulator: &mut Self, + values: RVec, + ) -> RResult<(), RString>, + + // Evaluate and return a ScalarValues as protobuf bytes + pub evaluate: unsafe extern "C" fn(accumulator: &Self) -> RResult, RString>, + + pub size: unsafe extern "C" fn(accumulator: &Self) -> usize, + + pub state: + unsafe extern "C" fn(accumulator: &Self) -> RResult>, RString>, + + pub merge_batch: unsafe extern "C" fn( + accumulator: &mut Self, + states: RVec, + ) -> RResult<(), RString>, + + pub retract_batch: unsafe extern "C" fn( + accumulator: &mut Self, + values: RVec, + ) -> RResult<(), RString>, + + pub supports_retract_batch: unsafe extern "C" fn(accumulator: &Self) -> bool, + + /// Used to create a clone on the provider of the accumulator. This should + /// only need to be called by the receiver of the accumulator. + pub clone: unsafe extern "C" fn(accumulator: &Self) -> Self, + + /// Release the memory of the private data when it is no longer being used. + pub release: unsafe extern "C" fn(accumulator: &mut Self), + + /// Internal data. This is only to be accessed by the provider of the accumulator. + /// A [`ForeignAccumulator`] should never attempt to access this data. 
+ pub private_data: *mut c_void, +} + +unsafe impl Send for FFI_Accumulator {} +unsafe impl Sync for FFI_Accumulator {} + +pub struct AccumulatorPrivateData { + pub accumulator: Arc>, +} + +unsafe extern "C" fn update_batch_fn_wrapper( + accumulator: &mut FFI_Accumulator, + values: RVec, +) -> RResult<(), RString> { + let private_data = accumulator.private_data as *mut AccumulatorPrivateData; + let accum_data = &mut (*private_data); + + let values_arrays = values + .into_iter() + .map(|v| v.try_into().map_err(DataFusionError::from)) + .collect::>>(); + let values_arrays = rresult_return!(values_arrays); + + let mut accumulator = rresult_return!(accum_data + .accumulator + .lock() + .map_err(|e| DataFusionError::Execution(e.to_string()))); + + rresult!(accumulator.update_batch(&values_arrays)) +} + +unsafe extern "C" fn evaluate_fn_wrapper( + accumulator: &FFI_Accumulator, +) -> RResult, RString> { + let private_data = accumulator.private_data as *mut AccumulatorPrivateData; + let accum_data = &mut (*private_data); + let mut accumulator_internal = rresult_return!(accum_data + .accumulator + .lock() + .map_err(|_| DataFusionError::Execution( + "Unable to aquire lock on FFI Accumulator".to_string() + ))); + + let scalar_result = rresult_return!(accumulator_internal.evaluate()); + let proto_result: datafusion_proto::protobuf::ScalarValue = + rresult_return!((&scalar_result).try_into()); + + RResult::ROk(proto_result.encode_to_vec().into()) +} + +unsafe extern "C" fn size_fn_wrapper(accumulator: &FFI_Accumulator) -> usize { + let private_data = accumulator.private_data as *mut AccumulatorPrivateData; + let accum_data = &mut (*private_data); + + accum_data + .accumulator + .lock() + .map(|accum| accum.size()) + .unwrap_or_default() +} + +unsafe extern "C" fn state_fn_wrapper( + accumulator: &FFI_Accumulator, +) -> RResult>, RString> { + let private_data = accumulator.private_data as *mut AccumulatorPrivateData; + let accum_data = &mut (*private_data); + let mut 
accumulator_internal = rresult_return!(accum_data + .accumulator + .lock() + .map_err(|_| DataFusionError::Execution( + "Unable to aquire lock on FFI Accumulator".to_string() + ))); + + let state = rresult_return!(accumulator_internal.state()); + let state = state + .into_iter() + .map(|state_val| { + datafusion_proto::protobuf::ScalarValue::try_from(&state_val) + .map_err(DataFusionError::from) + .map(|v| RVec::from(v.encode_to_vec())) + }) + .collect::>>() + .map(|state_vec| state_vec.into()); + + rresult!(state) +} + +unsafe extern "C" fn merge_batch_fn_wrapper(accumulator: &mut FFI_Accumulator, + states: RVec, +) -> RResult<(), RString> { + let private_data = accumulator.private_data as *mut AccumulatorPrivateData; + let accum_data = &mut (*private_data); + let mut accumulator_internal = rresult_return!(accum_data + .accumulator + .lock() + .map_err(|_| DataFusionError::Execution( + "Unable to aquire lock on FFI Accumulator".to_string() + ))); + + let states = rresult_return!(states.into_iter() + .map(|state| ArrayRef::try_from(state).map_err(DataFusionError::from)) + .collect::>>()); + + rresult!(accumulator_internal.merge_batch(&states)) +} + +unsafe extern "C" fn retract_batch_fn_wrapper(accumulator: &mut FFI_Accumulator, + values: RVec, +) -> RResult<(), RString> { + let private_data = accumulator.private_data as *mut AccumulatorPrivateData; + let accum_data = &mut (*private_data); + let mut accumulator_internal = rresult_return!(accum_data + .accumulator + .lock() + .map_err(|_| DataFusionError::Execution( + "Unable to aquire lock on FFI Accumulator".to_string() + ))); + + let values = rresult_return!(values.into_iter() + .map(|state| ArrayRef::try_from(state).map_err(DataFusionError::from)) + .collect::>>()); + + rresult!(accumulator_internal.retract_batch(&values)) +} + +unsafe extern "C" fn supports_retract_batch_fn_wrapper(accumulator: &FFI_Accumulator, +) -> bool { + let private_data = accumulator.private_data as *mut AccumulatorPrivateData; + let 
accum_data = &mut (*private_data); + accum_data + .accumulator + .lock() + .map(|accum| accum.supports_retract_batch()) + .unwrap_or(false) +} + +unsafe extern "C" fn release_fn_wrapper(accumulator: &mut FFI_Accumulator) { + let private_data = + Box::from_raw(accumulator.private_data as *mut AccumulatorPrivateData); + drop(private_data); +} + +unsafe extern "C" fn clone_fn_wrapper(accumulator: &FFI_Accumulator) -> FFI_Accumulator { + let private_data = accumulator.private_data as *const AccumulatorPrivateData; + let accum_data = &(*private_data); + + Arc::clone(&accum_data.accumulator).into() +} + +impl Clone for FFI_Accumulator { + fn clone(&self) -> Self { + unsafe { (self.clone)(self) } + } +} + +impl From>> for FFI_Accumulator { + fn from(accumulator: Arc>) -> Self { + let private_data = Box::new(AccumulatorPrivateData { accumulator }); + + Self { + update_batch: update_batch_fn_wrapper, + evaluate: evaluate_fn_wrapper, + size: size_fn_wrapper, + state: state_fn_wrapper, + merge_batch: merge_batch_fn_wrapper, + retract_batch: retract_batch_fn_wrapper, + supports_retract_batch: supports_retract_batch_fn_wrapper, + + clone: clone_fn_wrapper, + release: release_fn_wrapper, + private_data: Box::into_raw(private_data) as *mut c_void, + } + } +} + +impl Drop for FFI_Accumulator { + fn drop(&mut self) { + unsafe { (self.release)(self) } + } +} + +/// This struct is used to access an UDF provided by a foreign +/// library across a FFI boundary. +/// +/// The ForeignAccumulator is to be used by the caller of the UDF, so it has +/// no knowledge or access to the private data. All interaction with the UDF +/// must occur through the functions defined in FFI_Accumulator. 
+#[derive(Debug)] +pub struct ForeignAccumulator { + accumulator: FFI_Accumulator, +} + +unsafe impl Send for ForeignAccumulator {} +unsafe impl Sync for ForeignAccumulator {} + +impl From<&FFI_Accumulator> for ForeignAccumulator { + fn from(accumulator: &FFI_Accumulator) -> Self { + Self { + accumulator: accumulator.clone(), + } + } +} + +impl Accumulator for ForeignAccumulator { + fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + unsafe { + let values = values + .iter() + .map(WrappedArray::try_from) + .collect::, ArrowError>>()?; + df_result!((self.accumulator.update_batch)( + &mut self.accumulator, + values.into() + )) + } + } + + fn evaluate(&mut self) -> Result { + unsafe { + let scalar_bytes = + df_result!((self.accumulator.evaluate)(&self.accumulator))?; + + let proto_scalar = + datafusion_proto::protobuf::ScalarValue::decode(scalar_bytes.as_ref()) + .map_err(|e| DataFusionError::External(Box::new(e)))?; + + ScalarValue::try_from(&proto_scalar).map_err(DataFusionError::from) + } + } + + fn size(&self) -> usize { + unsafe { (self.accumulator.size)(&self.accumulator) } + } + + fn state(&mut self) -> Result> { + unsafe { + let state_protos = df_result!((self.accumulator.state)(&self.accumulator))?; + + state_protos + .into_iter() + .map(|proto_bytes| { + datafusion_proto::protobuf::ScalarValue::decode(proto_bytes.as_ref()) + .map_err(|e| DataFusionError::External(Box::new(e))) + .and_then(|proto_value| { + ScalarValue::try_from(&proto_value) + .map_err(DataFusionError::from) + }) + }) + .collect::>>() + } + } + + fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { + unsafe { + let states = states + .iter() + .map(WrappedArray::try_from) + .collect::, ArrowError>>()?; + df_result!((self.accumulator.merge_batch)( + &mut self.accumulator, + states.into() + )) + } + } + + fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + unsafe { + let values = values + .iter() + .map(WrappedArray::try_from) + .collect::, 
ArrowError>>()?; + df_result!((self.accumulator.retract_batch)( + &mut self.accumulator, + values.into() + )) + } + } + + fn supports_retract_batch(&self) -> bool { + unsafe { (self.accumulator.supports_retract_batch)(&self.accumulator) } + } +} diff --git a/datafusion/ffi/src/udaf/mod.rs b/datafusion/ffi/src/udaf/mod.rs new file mode 100644 index 000000000000..1c8794f96f0e --- /dev/null +++ b/datafusion/ffi/src/udaf/mod.rs @@ -0,0 +1,302 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use std::{ffi::c_void, sync::Arc}; + +use abi_stable::{ + std_types::{RResult, RString, RVec}, + StableAbi, +}; +use arrow::datatypes::DataType; +use arrow::ffi::{from_ffi, to_ffi, FFI_ArrowSchema}; +use datafusion::{ + error::DataFusionError, + logical_expr::{ + function::{AccumulatorArgs, AggregateFunctionSimplification, StateFieldsArgs}, + utils::AggregateOrderSensitivity, + Accumulator, GroupsAccumulator, ReversedUDAF, + }, +}; +use datafusion::{ + error::Result, + logical_expr::{ + AggregateUDF, AggregateUDFImpl, ColumnarValue, ScalarFunctionArgs, Signature, + }, +}; + +use crate::{ + arrow_wrappers::{WrappedArray, WrappedSchema}, + df_result, rresult, rresult_return, + signature::{self, rvec_wrapped_to_vec_datatype, FFI_Signature}, +}; + +mod accumulator; + +/// A stable struct for sharing a [`AggregateUDF`] across FFI boundaries. +#[repr(C)] +#[derive(Debug, StableAbi)] +#[allow(non_camel_case_types)] +pub struct FFI_AggregateUDF { + /// Return the udaf name. + pub name: RString, + + pub signature: unsafe extern "C" fn(udaf: &Self) -> RResult, + + pub aliases: unsafe extern "C" fn(udaf: &Self) -> RVec, + + pub return_type: unsafe extern "C" fn( + udaf: &Self, + arg_types: RVec, + ) -> RResult, + + pub is_nullable: bool, + + + /// Used to create a clone on the provider of the udaf. This should + /// only need to be called by the receiver of the udaf. + pub clone: unsafe extern "C" fn(udaf: &Self) -> Self, + + /// Release the memory of the private data when it is no longer being used. + pub release: unsafe extern "C" fn(udaf: &mut Self), + + /// Internal data. This is only to be accessed by the provider of the udaf. + /// A [`ForeignAggregateUDF`] should never attempt to access this data. 
+ pub private_data: *mut c_void, +} + +unsafe impl Send for FFI_AggregateUDF {} +unsafe impl Sync for FFI_AggregateUDF {} + +pub struct AggregateUDFPrivateData { + pub udaf: Arc, +} + +unsafe extern "C" fn name_fn_wrapper(udaf: &FFI_AggregateUDF) -> RString { + let private_data = udaf.private_data as *const AggregateUDFPrivateData; + let udaf = &(*private_data).udaf; + + udaf.name().into() +} + +unsafe extern "C" fn signature_fn_wrapper( + udaf: &FFI_AggregateUDF, +) -> RResult { + let private_data = udaf.private_data as *const AggregateUDFPrivateData; + let udaf = &(*private_data).udaf; + + rresult!(udaf.signature().try_into()) +} + +unsafe extern "C" fn aliases_fn_wrapper(udaf: &FFI_AggregateUDF) -> RVec { + let private_data = udaf.private_data as *const AggregateUDFPrivateData; + let udaf = &(*private_data).udaf; + + udaf.aliases().iter().map(|s| s.to_owned().into()).collect() +} + +unsafe extern "C" fn return_type_fn_wrapper( + udaf: &FFI_AggregateUDF, + arg_types: RVec, +) -> RResult { + let private_data = udaf.private_data as *const AggregateUDFPrivateData; + let udaf = &(*private_data).udaf; + + let arg_types = rresult_return!(rvec_wrapped_to_vec_datatype(&arg_types)); + + let return_type = udaf + .return_type(&arg_types) + .and_then(|v| FFI_ArrowSchema::try_from(v).map_err(DataFusionError::from)) + .map(WrappedSchema); + + rresult!(return_type) +} + +unsafe extern "C" fn release_fn_wrapper(udaf: &mut FFI_AggregateUDF) { + let private_data = Box::from_raw(udaf.private_data as *mut AggregateUDFPrivateData); + drop(private_data); +} + +unsafe extern "C" fn clone_fn_wrapper(udaf: &FFI_AggregateUDF) -> FFI_AggregateUDF { + let private_data = udaf.private_data as *const AggregateUDFPrivateData; + let udaf_data = &(*private_data); + + Arc::clone(&udaf_data.udaf).into() +} + +impl Clone for FFI_AggregateUDF { + fn clone(&self) -> Self { + unsafe { (self.clone)(self) } + } +} + +impl From> for FFI_AggregateUDF { + fn from(udaf: Arc) -> Self { + let name = 
udaf.name().into(); + let is_nullable = udaf.is_nullable(); + + let private_data = Box::new(AggregateUDFPrivateData { udaf }); + + Self { + name, + is_nullable, + signature: signature_fn_wrapper, + aliases: aliases_fn_wrapper, + return_type: return_type_fn_wrapper, + clone: clone_fn_wrapper, + release: release_fn_wrapper, + private_data: Box::into_raw(private_data) as *mut c_void, + } + } +} + +impl Drop for FFI_AggregateUDF { + fn drop(&mut self) { + unsafe { (self.release)(self) } + } +} + +/// This struct is used to access an UDF provided by a foreign +/// library across a FFI boundary. +/// +/// The ForeignAggregateUDF is to be used by the caller of the UDF, so it has +/// no knowledge or access to the private data. All interaction with the UDF +/// must occur through the functions defined in FFI_AggregateUDF. +#[derive(Debug)] +pub struct ForeignAggregateUDF { + signature: Signature, + aliases: Vec, + udaf: FFI_AggregateUDF, +} + +unsafe impl Send for ForeignAggregateUDF {} +unsafe impl Sync for ForeignAggregateUDF {} + +impl TryFrom<&FFI_AggregateUDF> for ForeignAggregateUDF { + type Error = DataFusionError; + + fn try_from(udaf: &FFI_AggregateUDF) -> Result { + unsafe { + let ffi_signature = df_result!((udaf.signature)(udaf))?; + let signature = (&ffi_signature).try_into()?; + + let aliases = (udaf.aliases)(udaf) + .into_iter() + .map(|s| s.to_string()) + .collect(); + + Ok(Self { + udaf: udaf.clone(), + signature, + aliases, + }) + } + } +} + +impl AggregateUDFImpl for ForeignAggregateUDF { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn name(&self) -> &str { + self.udaf.name.as_str() + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + let arg_types = signature::vec_datatype_to_rvec_wrapped(arg_types)?; + + let result = unsafe { (self.udaf.return_type)(&self.udaf, arg_types) }; + + let result = df_result!(result); + + result.and_then(|r| 
(&r.0).try_into().map_err(DataFusionError::from)) + } + + fn is_nullable(&self) -> bool { + self.udaf.is_nullable + } + + fn accumulator(&self, acc_args: AccumulatorArgs) -> Result> {} + + fn state_fields(&self, args: StateFieldsArgs) -> Result> {} + + fn groups_accumulator_supported(&self, _args: AccumulatorArgs) -> bool {} + + fn create_groups_accumulator( + &self, + _args: AccumulatorArgs, + ) -> Result> { + } + + fn aliases(&self) -> &[String] { + &self.aliases + } + + fn create_sliding_accumulator( + &self, + args: AccumulatorArgs, + ) -> Result> { + } + + fn with_beneficial_ordering( + self: Arc, + _beneficial_ordering: bool, + ) -> Result>> { + } + + fn order_sensitivity(&self) -> AggregateOrderSensitivity {} + + fn simplify(&self) -> Option {} + + fn reverse_expr(&self) -> ReversedUDAF {} + + fn coerce_types(&self, _arg_types: &[DataType]) -> Result> {} + + fn equals(&self, other: &dyn AggregateUDFImpl) -> bool {} + + fn is_descending(&self) -> Option {} + + fn value_from_stats(&self, _statistics_args: &StatisticsArgs) -> Option { + } + + fn default_value(&self, data_type: &DataType) -> Result {} + + fn set_monotonicity(&self, _data_type: &DataType) -> SetMonotonicity {} +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_round_trip_udaf() -> Result<()> { + let original_udaf = datafusion::functions::math::abs::AbsFunc::new(); + let original_udaf = Arc::new(AggregateUDF::from(original_udaf)); + + let local_udaf: FFI_AggregateUDF = Arc::clone(&original_udaf).into(); + + let foreign_udaf: ForeignAggregateUDF = (&local_udaf).try_into()?; + + assert!(original_udaf.name() == foreign_udaf.name()); + + Ok(()) + } +} From efa2e5c0da9cb93ca87e36470cf3cbf04fe6eafe Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Tue, 18 Feb 2025 16:25:57 -0500 Subject: [PATCH 02/32] Intermediate work. 
Going through groups accumulator --- Cargo.lock | 1 + datafusion/ffi/Cargo.toml | 1 + datafusion/ffi/src/udaf/accumulator.rs | 121 ++--- datafusion/ffi/src/udaf/accumulator_args.rs | 108 +++++ datafusion/ffi/src/udaf/groups_accumulator.rs | 445 ++++++++++++++++++ datafusion/ffi/src/udaf/mod.rs | 163 ++++++- 6 files changed, 750 insertions(+), 89 deletions(-) create mode 100644 datafusion/ffi/src/udaf/accumulator_args.rs create mode 100644 datafusion/ffi/src/udaf/groups_accumulator.rs diff --git a/Cargo.lock b/Cargo.lock index 1934063413ed..d0a125cbdbbd 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2243,6 +2243,7 @@ dependencies = [ "async-trait", "datafusion", "datafusion-proto", + "datafusion-proto-common", "doc-comment", "futures", "log", diff --git a/datafusion/ffi/Cargo.toml b/datafusion/ffi/Cargo.toml index 29f40df51444..2dadc2e5d541 100644 --- a/datafusion/ffi/Cargo.toml +++ b/datafusion/ffi/Cargo.toml @@ -45,6 +45,7 @@ async-ffi = { version = "0.5.0", features = ["abi_stable"] } async-trait = { workspace = true } datafusion = { workspace = true, default-features = false } datafusion-proto = { workspace = true } +datafusion-proto-common = { workspace = true } futures = { workspace = true } log = { workspace = true } prost = { workspace = true } diff --git a/datafusion/ffi/src/udaf/accumulator.rs b/datafusion/ffi/src/udaf/accumulator.rs index b46c9a479f35..4b709ff1a3c3 100644 --- a/datafusion/ffi/src/udaf/accumulator.rs +++ b/datafusion/ffi/src/udaf/accumulator.rs @@ -15,10 +15,7 @@ use datafusion::{ }; use prost::Message; -use crate::{ - arrow_wrappers::WrappedArray, - df_result, rresult, rresult_return, -}; +use crate::{arrow_wrappers::WrappedArray, df_result, rresult, rresult_return}; #[repr(C)] #[derive(Debug, StableAbi)] @@ -51,7 +48,7 @@ pub struct FFI_Accumulator { /// Used to create a clone on the provider of the accumulator. This should /// only need to be called by the receiver of the accumulator. 
- pub clone: unsafe extern "C" fn(accumulator: &Self) -> Self, + // pub clone: unsafe extern "C" fn(accumulator: &Self) -> Self, /// Release the memory of the private data when it is no longer being used. pub release: unsafe extern "C" fn(accumulator: &mut Self), @@ -65,7 +62,7 @@ unsafe impl Send for FFI_Accumulator {} unsafe impl Sync for FFI_Accumulator {} pub struct AccumulatorPrivateData { - pub accumulator: Arc>, + pub accumulator: Box, } unsafe extern "C" fn update_batch_fn_wrapper( @@ -81,12 +78,7 @@ unsafe extern "C" fn update_batch_fn_wrapper( .collect::>>(); let values_arrays = rresult_return!(values_arrays); - let mut accumulator = rresult_return!(accum_data - .accumulator - .lock() - .map_err(|e| DataFusionError::Execution(e.to_string()))); - - rresult!(accumulator.update_batch(&values_arrays)) + rresult!(accum_data.accumulator.update_batch(&values_arrays)) } unsafe extern "C" fn evaluate_fn_wrapper( @@ -94,14 +86,8 @@ unsafe extern "C" fn evaluate_fn_wrapper( ) -> RResult, RString> { let private_data = accumulator.private_data as *mut AccumulatorPrivateData; let accum_data = &mut (*private_data); - let mut accumulator_internal = rresult_return!(accum_data - .accumulator - .lock() - .map_err(|_| DataFusionError::Execution( - "Unable to aquire lock on FFI Accumulator".to_string() - ))); - - let scalar_result = rresult_return!(accumulator_internal.evaluate()); + + let scalar_result = rresult_return!(accum_data.accumulator.evaluate()); let proto_result: datafusion_proto::protobuf::ScalarValue = rresult_return!((&scalar_result).try_into()); @@ -112,11 +98,7 @@ unsafe extern "C" fn size_fn_wrapper(accumulator: &FFI_Accumulator) -> usize { let private_data = accumulator.private_data as *mut AccumulatorPrivateData; let accum_data = &mut (*private_data); - accum_data - .accumulator - .lock() - .map(|accum| accum.size()) - .unwrap_or_default() + accum_data.accumulator.size() } unsafe extern "C" fn state_fn_wrapper( @@ -124,14 +106,8 @@ unsafe extern "C" fn 
state_fn_wrapper( ) -> RResult>, RString> { let private_data = accumulator.private_data as *mut AccumulatorPrivateData; let accum_data = &mut (*private_data); - let mut accumulator_internal = rresult_return!(accum_data - .accumulator - .lock() - .map_err(|_| DataFusionError::Execution( - "Unable to aquire lock on FFI Accumulator".to_string() - ))); - - let state = rresult_return!(accumulator_internal.state()); + + let state = rresult_return!(accum_data.accumulator.state()); let state = state .into_iter() .map(|state_val| { @@ -145,53 +121,42 @@ unsafe extern "C" fn state_fn_wrapper( rresult!(state) } -unsafe extern "C" fn merge_batch_fn_wrapper(accumulator: &mut FFI_Accumulator, +unsafe extern "C" fn merge_batch_fn_wrapper( + accumulator: &mut FFI_Accumulator, states: RVec, ) -> RResult<(), RString> { let private_data = accumulator.private_data as *mut AccumulatorPrivateData; let accum_data = &mut (*private_data); - let mut accumulator_internal = rresult_return!(accum_data - .accumulator - .lock() - .map_err(|_| DataFusionError::Execution( - "Unable to aquire lock on FFI Accumulator".to_string() - ))); - - let states = rresult_return!(states.into_iter() + + let states = rresult_return!(states + .into_iter() .map(|state| ArrayRef::try_from(state).map_err(DataFusionError::from)) .collect::>>()); - rresult!(accumulator_internal.merge_batch(&states)) + rresult!(accum_data.accumulator.merge_batch(&states)) } -unsafe extern "C" fn retract_batch_fn_wrapper(accumulator: &mut FFI_Accumulator, +unsafe extern "C" fn retract_batch_fn_wrapper( + accumulator: &mut FFI_Accumulator, values: RVec, ) -> RResult<(), RString> { let private_data = accumulator.private_data as *mut AccumulatorPrivateData; let accum_data = &mut (*private_data); - let mut accumulator_internal = rresult_return!(accum_data - .accumulator - .lock() - .map_err(|_| DataFusionError::Execution( - "Unable to aquire lock on FFI Accumulator".to_string() - ))); - - let values = rresult_return!(values.into_iter() + + 
let values = rresult_return!(values + .into_iter() .map(|state| ArrayRef::try_from(state).map_err(DataFusionError::from)) .collect::>>()); - rresult!(accumulator_internal.retract_batch(&values)) + rresult!(accum_data.accumulator.retract_batch(&values)) } -unsafe extern "C" fn supports_retract_batch_fn_wrapper(accumulator: &FFI_Accumulator, +unsafe extern "C" fn supports_retract_batch_fn_wrapper( + accumulator: &FFI_Accumulator, ) -> bool { let private_data = accumulator.private_data as *mut AccumulatorPrivateData; let accum_data = &mut (*private_data); - accum_data - .accumulator - .lock() - .map(|accum| accum.supports_retract_batch()) - .unwrap_or(false) + accum_data.accumulator.supports_retract_batch() } unsafe extern "C" fn release_fn_wrapper(accumulator: &mut FFI_Accumulator) { @@ -200,23 +165,21 @@ unsafe extern "C" fn release_fn_wrapper(accumulator: &mut FFI_Accumulator) { drop(private_data); } -unsafe extern "C" fn clone_fn_wrapper(accumulator: &FFI_Accumulator) -> FFI_Accumulator { - let private_data = accumulator.private_data as *const AccumulatorPrivateData; - let accum_data = &(*private_data); - - Arc::clone(&accum_data.accumulator).into() -} +// unsafe extern "C" fn clone_fn_wrapper(accumulator: &FFI_Accumulator) -> FFI_Accumulator { +// let private_data = accumulator.private_data as *const AccumulatorPrivateData; +// let accum_data = &(*private_data); -impl Clone for FFI_Accumulator { - fn clone(&self) -> Self { - unsafe { (self.clone)(self) } - } -} +// Box::new(accum_data.accumulator).into() +// } -impl From>> for FFI_Accumulator { - fn from(accumulator: Arc>) -> Self { - let private_data = Box::new(AccumulatorPrivateData { accumulator }); +// impl Clone for FFI_Accumulator { +// fn clone(&self) -> Self { +// unsafe { (self.clone)(self) } +// } +// } +impl From> for FFI_Accumulator { + fn from(accumulator: Box) -> Self { Self { update_batch: update_batch_fn_wrapper, evaluate: evaluate_fn_wrapper, @@ -226,9 +189,9 @@ impl From>> for FFI_Accumulator { 
retract_batch: retract_batch_fn_wrapper, supports_retract_batch: supports_retract_batch_fn_wrapper, - clone: clone_fn_wrapper, + // clone: clone_fn_wrapper, release: release_fn_wrapper, - private_data: Box::into_raw(private_data) as *mut c_void, + private_data: Box::into_raw(accumulator) as *mut c_void, } } } @@ -253,11 +216,9 @@ pub struct ForeignAccumulator { unsafe impl Send for ForeignAccumulator {} unsafe impl Sync for ForeignAccumulator {} -impl From<&FFI_Accumulator> for ForeignAccumulator { - fn from(accumulator: &FFI_Accumulator) -> Self { - Self { - accumulator: accumulator.clone(), - } +impl From for ForeignAccumulator { + fn from(accumulator: FFI_Accumulator) -> Self { + Self { accumulator } } } diff --git a/datafusion/ffi/src/udaf/accumulator_args.rs b/datafusion/ffi/src/udaf/accumulator_args.rs new file mode 100644 index 000000000000..2eeccca56c96 --- /dev/null +++ b/datafusion/ffi/src/udaf/accumulator_args.rs @@ -0,0 +1,108 @@ +use std::sync::Arc; + +use abi_stable::{ + std_types::{RString, RVec}, + StableAbi, +}; +use arrow::{datatypes::Schema, ffi::FFI_ArrowSchema}; +use datafusion::{ + error::{DataFusionError, Result}, + logical_expr::function::AccumulatorArgs, + prelude::SessionContext, +}; +use datafusion_proto::{ + physical_plan::{ + from_proto::{parse_physical_exprs, parse_physical_sort_exprs}, + to_proto::{serialize_physical_exprs, serialize_physical_sort_exprs}, + DefaultPhysicalExtensionCodec, + }, + protobuf::PhysicalAggregateExprNode, +}; +use prost::Message; + +use crate::{arrow_wrappers::WrappedSchema, rresult_return}; + +#[repr(C)] +#[derive(Debug, StableAbi)] +#[allow(non_camel_case_types)] +pub struct FFI_AccumulatorArgs { + return_type: WrappedSchema, + schema: WrappedSchema, + is_reversed: bool, + name: RString, + physical_expr_def: RVec, +} + +impl FFI_AccumulatorArgs { + pub fn to_accumulator_args(&self) -> Result { + let proto_def = + PhysicalAggregateExprNode::decode(self.physical_expr_def.as_ref()) + .map_err(|e| 
DataFusionError::Execution(e.to_string()))?; + + let return_type = &(&self.return_type.0).try_into()?; + let schema = &Arc::new(Schema::try_from(&self.schema.0)?); + + let default_ctx = SessionContext::new(); + let codex = DefaultPhysicalExtensionCodec {}; + + // let proto_ordering_req = + // rresult_return!(PhysicalSortExprNodeCollection::decode(ordering_req.as_ref())); + let ordering_req = &parse_physical_sort_exprs( + &proto_def.ordering_req, + &default_ctx, + &schema, + &codex, + )?; + + let exprs = &rresult_return!(parse_physical_exprs( + &proto_def.expr, + &default_ctx, + &schema, + &codex + )); + + Ok(AccumulatorArgs { + return_type, + schema, + ignore_nulls: proto_def.ignore_nulls, + ordering_req, + is_reversed: self.is_reversed, + name: self.name.as_str(), + is_distinct: proto_def.distinct, + exprs, + }) + } +} + +impl<'a> TryFrom> for FFI_AccumulatorArgs { + type Error = DataFusionError; + + fn try_from(args: AccumulatorArgs) -> std::result::Result { + let return_type = WrappedSchema(FFI_ArrowSchema::try_from(args.return_type)?); + let schema = WrappedSchema(FFI_ArrowSchema::try_from(args.schema)?); + + let codec = DefaultPhysicalExtensionCodec {}; + let ordering_req = + serialize_physical_sort_exprs(args.ordering_req.to_owned(), &codec)?; + + let expr = serialize_physical_exprs(args.exprs, &codec)?; + + let physical_expr_def = PhysicalAggregateExprNode { + expr, + ordering_req, + distinct: args.is_distinct, + ignore_nulls: args.ignore_nulls, + fun_definition: None, + aggregate_function: None, + }; + let physical_expr_def = physical_expr_def.encode_to_vec().into(); + + Ok(Self { + return_type, + schema, + is_reversed: args.is_reversed, + name: args.name.into(), + physical_expr_def, + }) + } +} diff --git a/datafusion/ffi/src/udaf/groups_accumulator.rs b/datafusion/ffi/src/udaf/groups_accumulator.rs new file mode 100644 index 000000000000..0c5eb9475d78 --- /dev/null +++ b/datafusion/ffi/src/udaf/groups_accumulator.rs @@ -0,0 +1,445 @@ +use std::{ + 
ffi::c_void, + sync::{Arc, Mutex}, +}; + +use abi_stable::{ + std_types::{ROption, RResult, RString, RVec}, + StableAbi, +}; +use arrow::{ + array::{Array, ArrayRef, BooleanArray}, + error::ArrowError, + ffi::{from_ffi, to_ffi, FFI_ArrowArray}, +}; +use datafusion::{ + error::{DataFusionError, Result}, + logical_expr::{EmitTo, GroupsAccumulator}, + scalar::ScalarValue, +}; +use prost::Message; + +use crate::{ + arrow_wrappers::{WrappedArray, WrappedSchema}, + df_result, rresult, rresult_return, +}; + +#[repr(C)] +#[derive(Debug, StableAbi)] +#[allow(non_camel_case_types)] +pub struct FFI_GroupsAccumulator { + pub update_batch: unsafe extern "C" fn( + accumulator: &mut Self, + values: RVec, + group_indices: RVec, + opt_filter: ROption, + total_num_groups: usize, + ) -> RResult<(), RString>, + + // Evaluate and return a ScalarValues as protobuf bytes + pub evaluate: unsafe extern "C" fn( + accumulator: &Self, + emit_to: FFI_EmitTo, + ) -> RResult, + + pub size: unsafe extern "C" fn(accumulator: &Self) -> usize, + + pub state: unsafe extern "C" fn( + accumulator: &Self, + emit_to: FFI_EmitTo, + ) -> RResult, RString>, + + pub merge_batch: unsafe extern "C" fn( + accumulator: &mut Self, + values: RVec, + group_indices: RVec, + opt_filter: ROption, + total_num_groups: usize, + ) -> RResult<(), RString>, + + pub convert_to_state: unsafe extern "C" fn( + accumulator: &Self, + values: RVec, + opt_filter: ROption, + ) + -> RResult, RString>, + + pub supports_convert_to_state: bool, + + /// Used to create a clone on the provider of the accumulator. This should + /// only need to be called by the receiver of the accumulator. + // pub clone: unsafe extern "C" fn(accumulator: &Self) -> Self, + + /// Release the memory of the private data when it is no longer being used. + pub release: unsafe extern "C" fn(accumulator: &mut Self), + + /// Internal data. This is only to be accessed by the provider of the accumulator. 
+ /// A [`ForeignGroupsAccumulator`] should never attempt to access this data. + pub private_data: *mut c_void, +} + +unsafe impl Send for FFI_GroupsAccumulator {} +unsafe impl Sync for FFI_GroupsAccumulator {} + +pub struct GroupsAccumulatorPrivateData { + pub accumulator: Box, +} + +unsafe extern "C" fn update_batch_fn_wrapper( + accumulator: &mut FFI_GroupsAccumulator, + values: RVec, + group_indices: RVec, + opt_filter: ROption, + total_num_groups: usize, +) -> RResult<(), RString> { + let private_data = accumulator.private_data as *mut GroupsAccumulatorPrivateData; + let accum_data = &mut (*private_data); + + let values_arrays = values + .into_iter() + .map(|v| v.try_into().map_err(DataFusionError::from)) + .collect::>>(); + let values_arrays = rresult_return!(values_arrays); + + let group_indices: Vec = group_indices.into_iter().collect(); + + let maybe_filter = opt_filter.into_option().and_then(|filter| { + match ArrayRef::try_from(filter) { + Ok(v) => Some(v), + Err(e) => { + log::warn!("Error during FFI array conversion. Ignoring optional filter in groups accumulator. 
{}", e); + None + } + } + }).map(|arr| arr.into_data()); + let opt_filter = maybe_filter.map(|arr| BooleanArray::from(arr)); + + rresult!(accum_data.accumulator.update_batch( + &values_arrays, + &group_indices, + opt_filter.as_ref(), + total_num_groups + )) +} + +unsafe extern "C" fn evaluate_fn_wrapper( + accumulator: &FFI_GroupsAccumulator, + emit_to: FFI_EmitTo, +) -> RResult { + let private_data = accumulator.private_data as *mut GroupsAccumulatorPrivateData; + let accum_data = &mut (*private_data); + + let result = rresult_return!(accum_data.accumulator.evaluate(emit_to.into())); + + rresult!(WrappedArray::try_from(&result)) +} + +unsafe extern "C" fn size_fn_wrapper(accumulator: &FFI_GroupsAccumulator) -> usize { + let private_data = accumulator.private_data as *mut GroupsAccumulatorPrivateData; + let accum_data = &mut (*private_data); + + accum_data.accumulator.size() +} + +unsafe extern "C" fn state_fn_wrapper( + accumulator: &FFI_GroupsAccumulator, + emit_to: FFI_EmitTo, +) -> RResult, RString> { + let private_data = accumulator.private_data as *mut GroupsAccumulatorPrivateData; + let accum_data = &mut (*private_data); + + let state = rresult_return!(accum_data.accumulator.state(emit_to.into())); + rresult!(state + .into_iter() + .map(|arr| WrappedArray::try_from(&arr).map_err(DataFusionError::from)) + .collect::>>()) +} + +unsafe extern "C" fn merge_batch_fn_wrapper( + accumulator: &mut FFI_GroupsAccumulator, + values: RVec, + group_indices: RVec, + opt_filter: ROption, + total_num_groups: usize, +) -> RResult<(), RString> { + let private_data = accumulator.private_data as *mut GroupsAccumulatorPrivateData; + let accum_data = &mut (*private_data); + let values_arrays = values + .into_iter() + .map(|v| v.try_into().map_err(DataFusionError::from)) + .collect::>>(); + let values_arrays = rresult_return!(values_arrays); + + let group_indices: Vec = group_indices.into_iter().collect(); + + let maybe_filter = opt_filter.into_option().and_then(|filter| { + match 
ArrayRef::try_from(filter) { + Ok(v) => Some(v), + Err(e) => { + log::warn!("Error during FFI array conversion. Ignoring optional filter in groups accumulator. {}", e); + None + } + } + }).map(|arr| arr.into_data()); + let opt_filter = maybe_filter.map(|arr| BooleanArray::from(arr)); + + rresult!(accum_data.accumulator.merge_batch( + &values_arrays, + &group_indices, + opt_filter.as_ref(), + total_num_groups + )) +} + +unsafe extern "C" fn convert_to_state_fn_wrapper( + accumulator: &FFI_GroupsAccumulator, + values: RVec, + opt_filter: ROption, +) -> RResult, RString> { + let private_data = accumulator.private_data as *mut GroupsAccumulatorPrivateData; + let accum_data = &mut (*private_data); + + let values = rresult_return!(values + .into_iter() + .map(|v| ArrayRef::try_from(v).map_err(DataFusionError::from)) + .collect::>>()); + + let opt_filter = opt_filter.into_option().and_then(|filter| { + match ArrayRef::try_from(filter) { + Ok(v) => Some(v), + Err(e) => { + log::warn!("Error during FFI array conversion. Ignoring optional filter in groups accumulator. 
{}", e); + None + } + } + }).map(|arr| arr.into_data()).map(|arr| BooleanArray::from(arr)); + + let state = rresult_return!(accum_data + .accumulator + .convert_to_state(&values, opt_filter.as_ref())); + + rresult!(state + .iter() + .map(|arr| WrappedArray::try_from(arr).map_err(DataFusionError::from)) + .collect::>>()) +} + +unsafe extern "C" fn release_fn_wrapper(accumulator: &mut FFI_GroupsAccumulator) { + let private_data = + Box::from_raw(accumulator.private_data as *mut GroupsAccumulatorPrivateData); + drop(private_data); +} + +// unsafe extern "C" fn clone_fn_wrapper(accumulator: &FFI_GroupsAccumulator) -> FFI_GroupsAccumulator { +// let private_data = accumulator.private_data as *const GroupsAccumulatorPrivateData; +// let accum_data = &(*private_data); + +// Box::new(accum_data.accumulator).into() +// } + +// impl Clone for FFI_GroupsAccumulator { +// fn clone(&self) -> Self { +// unsafe { (self.clone)(self) } +// } +// } + +impl From> for FFI_GroupsAccumulator { + fn from(accumulator: Box) -> Self { + Self { + update_batch: update_batch_fn_wrapper, + evaluate: evaluate_fn_wrapper, + size: size_fn_wrapper, + state: state_fn_wrapper, + merge_batch: merge_batch_fn_wrapper, + convert_to_state: convert_to_state_fn_wrapper, + supports_convert_to_state: accumulator.supports_convert_to_state(), + + release: release_fn_wrapper, + private_data: Box::into_raw(accumulator) as *mut c_void, + } + } +} + +impl Drop for FFI_GroupsAccumulator { + fn drop(&mut self) { + unsafe { (self.release)(self) } + } +} + +/// This struct is used to access an UDF provided by a foreign +/// library across a FFI boundary. +/// +/// The ForeignGroupsAccumulator is to be used by the caller of the UDF, so it has +/// no knowledge or access to the private data. All interaction with the UDF +/// must occur through the functions defined in FFI_GroupsAccumulator. 
+#[derive(Debug)] +pub struct ForeignGroupsAccumulator { + accumulator: FFI_GroupsAccumulator, +} + +unsafe impl Send for ForeignGroupsAccumulator {} +unsafe impl Sync for ForeignGroupsAccumulator {} + +impl From for ForeignGroupsAccumulator { + fn from(accumulator: FFI_GroupsAccumulator) -> Self { + Self { accumulator } + } +} + +impl GroupsAccumulator for ForeignGroupsAccumulator { + fn update_batch( + &mut self, + values: &[ArrayRef], + group_indices: &[usize], + opt_filter: Option<&BooleanArray>, + total_num_groups: usize, + ) -> Result<()> { + unsafe { + let values = values + .iter() + .map(WrappedArray::try_from) + .collect::, ArrowError>>()?; + let group_indices = group_indices.iter().cloned().collect(); + let opt_filter = opt_filter + .map(|bool_array| to_ffi(&bool_array.to_data())) + .transpose()? + .map(|(array, schema)| WrappedArray { + array, + schema: WrappedSchema(schema), + }) + .into(); + + df_result!((self.accumulator.update_batch)( + &mut self.accumulator, + values.into(), + group_indices, + opt_filter, + total_num_groups + )) + } + } + + fn size(&self) -> usize { + unsafe { (self.accumulator.size)(&self.accumulator) } + } + + fn evaluate(&mut self, emit_to: EmitTo) -> Result { + unsafe { + let return_array = df_result!((self.accumulator.evaluate)( + &self.accumulator, + emit_to.into() + ))?; + + return_array.try_into().map_err(DataFusionError::from) + } + } + + fn state(&mut self, emit_to: EmitTo) -> Result> { + unsafe { + let returned_arrays = + df_result!((self.accumulator.state)(&self.accumulator, emit_to.into()))?; + + returned_arrays + .into_iter() + .map(|wrapped_array| { + wrapped_array.try_into().map_err(DataFusionError::from) + }) + .collect::>>() + } + } + + fn merge_batch( + &mut self, + values: &[ArrayRef], + group_indices: &[usize], + opt_filter: Option<&BooleanArray>, + total_num_groups: usize, + ) -> Result<()> { + unsafe { + let values = values + .iter() + .map(WrappedArray::try_from) + .collect::, ArrowError>>()?; + let 
group_indices = group_indices.iter().cloned().collect(); + let opt_filter = opt_filter + .map(|bool_array| to_ffi(&bool_array.to_data())) + .transpose()? + .map(|(array, schema)| WrappedArray { + array, + schema: WrappedSchema(schema), + }) + .into(); + + df_result!((self.accumulator.merge_batch)( + &mut self.accumulator, + values.into(), + group_indices, + opt_filter, + total_num_groups + )) + } + } + + fn convert_to_state( + &self, + values: &[ArrayRef], + opt_filter: Option<&BooleanArray>, + ) -> Result> { + unsafe { + let values = values + .iter() + .map(WrappedArray::try_from) + .collect::, ArrowError>>()?; + + let opt_filter = opt_filter + .map(|bool_array| to_ffi(&bool_array.to_data())) + .transpose()? + .map(|(array, schema)| WrappedArray { + array, + schema: WrappedSchema(schema), + }) + .into(); + + let returned_array = df_result!((self.accumulator.convert_to_state)( + &self.accumulator, + values, + opt_filter + ))?; + + returned_array + .into_iter() + .map(|arr| arr.try_into().map_err(DataFusionError::from)) + .collect() + } + } + + fn supports_convert_to_state(&self) -> bool { + self.accumulator.supports_convert_to_state + } +} + +#[repr(C)] +#[derive(Debug, StableAbi)] +#[allow(non_camel_case_types)] +pub enum FFI_EmitTo { + All, + First(usize), +} + +impl From for FFI_EmitTo { + fn from(value: EmitTo) -> Self { + match value { + EmitTo::All => Self::All, + EmitTo::First(v) => Self::First(v), + } + } +} + +impl From for EmitTo { + fn from(value: FFI_EmitTo) -> Self { + match value { + FFI_EmitTo::All => Self::All, + FFI_EmitTo::First(v) => Self::First(v), + } + } +} diff --git a/datafusion/ffi/src/udaf/mod.rs b/datafusion/ffi/src/udaf/mod.rs index 1c8794f96f0e..75b200b5d147 100644 --- a/datafusion/ffi/src/udaf/mod.rs +++ b/datafusion/ffi/src/udaf/mod.rs @@ -15,13 +15,18 @@ // specific language governing permissions and limitations // under the License. 
-use std::{ffi::c_void, sync::Arc}; +use std::{ + ffi::c_void, + sync::{Arc, Mutex}, +}; use abi_stable::{ - std_types::{RResult, RString, RVec}, + std_types::{RResult, RStr, RString, RVec}, StableAbi, }; -use arrow::datatypes::DataType; +use accumulator::FFI_Accumulator; +use accumulator_args::FFI_AccumulatorArgs; +use arrow::datatypes::{DataType, Field, SchemaRef}; use arrow::ffi::{from_ffi, to_ffi, FFI_ArrowSchema}; use datafusion::{ error::DataFusionError, @@ -30,6 +35,8 @@ use datafusion::{ utils::AggregateOrderSensitivity, Accumulator, GroupsAccumulator, ReversedUDAF, }, + physical_plan::aggregates::order, + prelude::SessionContext, }; use datafusion::{ error::Result, @@ -37,14 +44,31 @@ use datafusion::{ AggregateUDF, AggregateUDFImpl, ColumnarValue, ScalarFunctionArgs, Signature, }, }; +use datafusion_proto::{ + physical_plan::{ + from_proto::{parse_physical_exprs, parse_physical_sort_exprs}, + to_proto::{ + serialize_physical_expr, serialize_physical_exprs, + serialize_physical_sort_exprs, + }, + DefaultPhysicalExtensionCodec, + }, + protobuf::{PhysicalAggregateExprNode, PhysicalSortExprNodeCollection}, +}; +use groups_accumulator::FFI_GroupsAccumulator; use crate::{ arrow_wrappers::{WrappedArray, WrappedSchema}, df_result, rresult, rresult_return, - signature::{self, rvec_wrapped_to_vec_datatype, FFI_Signature}, + signature::{ + self, rvec_wrapped_to_vec_datatype, vec_datatype_to_rvec_wrapped, FFI_Signature, + }, }; +use prost::Message; mod accumulator; +mod accumulator_args; +mod groups_accumulator; /// A stable struct for sharing a [`AggregateUDF`] across FFI boundaries. 
#[repr(C)] @@ -65,6 +89,28 @@ pub struct FFI_AggregateUDF { pub is_nullable: bool, + pub groups_accumulator_supported: + unsafe extern "C" fn(udaf: &FFI_AggregateUDF, args: FFI_AccumulatorArgs) -> bool, + + pub accumulator: unsafe extern "C" fn( + udaf: &FFI_AggregateUDF, + args: FFI_AccumulatorArgs, + ) -> RResult, + + pub state_fields: unsafe extern "C" fn( + udaf: &FFI_AggregateUDF, + name: &RStr, + input_types: RVec, + return_type: WrappedSchema, + ordering_fields: RVec>, + is_distinct: bool, + ) -> RResult>, RString>, + + pub create_groups_accumulator: + unsafe extern "C" fn( + &FFI_AggregateUDF, + args: FFI_AccumulatorArgs, + ) -> RResult, /// Used to create a clone on the provider of the udaf. This should /// only need to be called by the receiver of the udaf. @@ -125,6 +171,49 @@ unsafe extern "C" fn return_type_fn_wrapper( rresult!(return_type) } +unsafe extern "C" fn accumulator_fn_wrapper( + udaf: &FFI_AggregateUDF, + args: FFI_AccumulatorArgs, +) -> RResult { + let private_data = udaf.private_data as *const AggregateUDFPrivateData; + let udaf = &(*private_data).udaf; + + let accumulator_args = rresult_return!(args.to_accumulator_args()); + + rresult!(udaf + .accumulator(accumulator_args) + .map(FFI_Accumulator::from)) +} + +unsafe extern "C" fn create_groups_accumulator_fn_wrapper( + udaf: &FFI_AggregateUDF, + args: FFI_AccumulatorArgs, +) -> RResult { + let private_data = udaf.private_data as *const AggregateUDFPrivateData; + let udaf = &(*private_data).udaf; + + let accumulator_args = rresult_return!(args.to_accumulator_args()); + + rresult!(udaf + .create_groups_accumulator(accumulator_args) + .map(FFI_GroupsAccumulator::from)) +} + +unsafe extern "C" fn groups_accumulator_supported_fn_wrapper( + udaf: &FFI_AggregateUDF, + args: FFI_AccumulatorArgs, +) -> bool { + let private_data = udaf.private_data as *const AggregateUDFPrivateData; + let udaf = &(*private_data).udaf; + + args.to_accumulator_args() + .map(|a| udaf.groups_accumulator_supported(a)) + 
.unwrap_or_else(|e| { + log::warn!("Unable to parse accumulator args. {}", e); + false + }) +} + unsafe extern "C" fn release_fn_wrapper(udaf: &mut FFI_AggregateUDF) { let private_data = Box::from_raw(udaf.private_data as *mut AggregateUDFPrivateData); drop(private_data); @@ -156,6 +245,9 @@ impl From> for FFI_AggregateUDF { signature: signature_fn_wrapper, aliases: aliases_fn_wrapper, return_type: return_type_fn_wrapper, + accumulator: accumulator_fn_wrapper, + create_groups_accumulator: create_groups_accumulator_fn_wrapper, + groups_accumulator_supported: groups_accumulator_supported_fn_wrapper, clone: clone_fn_wrapper, release: release_fn_wrapper, private_data: Box::into_raw(private_data) as *mut c_void, @@ -221,7 +313,7 @@ impl AggregateUDFImpl for ForeignAggregateUDF { } fn return_type(&self, arg_types: &[DataType]) -> Result { - let arg_types = signature::vec_datatype_to_rvec_wrapped(arg_types)?; + let arg_types = vec_datatype_to_rvec_wrapped(arg_types)?; let result = unsafe { (self.udaf.return_type)(&self.udaf, arg_types) }; @@ -234,16 +326,69 @@ impl AggregateUDFImpl for ForeignAggregateUDF { self.udaf.is_nullable } - fn accumulator(&self, acc_args: AccumulatorArgs) -> Result> {} + fn accumulator(&self, acc_args: AccumulatorArgs) -> Result> { + let args = acc_args.try_into()?; + + unsafe { df_result!((self.udaf.accumulator)(&self.udaf, args)) } + } + + fn state_fields(&self, args: StateFieldsArgs) -> Result> { + unsafe { + let name = RStr::from_str(args.name); + let input_types = vec_datatype_to_rvec_wrapped(args.input_types)?; + let return_type = WrappedSchema(FFI_ArrowSchema::try_from(args.return_type)?); + let ordering_fields = args + .ordering_fields + .iter() + .map(datafusion_proto::protobuf::Field::try_from) + .map(|v| v.map_err(DataFusionError::from)) + .collect::>>()? 
+ .into_iter() + .map(|proto_field| proto_field.encode_to_vec().into()) + .collect(); + + let fields = df_result!((self.udaf.state_fields)( + &self.udaf, + &name, + input_types, + return_type, + ordering_fields, + args.is_distinct + ))?; + let fields = fields + .into_iter() + .map(|field_bytes| { + datafusion_proto_common::Field::decode(field_bytes.as_ref()) + .map_err(|e| DataFusionError::Execution(e.to_string())) + }) + .collect::>>()?; + + datafusion_proto_common::from_proto::parse_proto_fields_to_fields( + fields.iter(), + ) + .map_err(|e| DataFusionError::Execution(e.to_string())) + } + } - fn state_fields(&self, args: StateFieldsArgs) -> Result> {} + fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool { + let args = match FFI_AccumulatorArgs::try_from(args) { + Ok(v) => v, + Err(e) => { + log::warn!("Attempting to convert accumulator arguments: {}", e); + return false; + } + }; - fn groups_accumulator_supported(&self, _args: AccumulatorArgs) -> bool {} + unsafe { (self.udaf.groups_accumulator_supported)(&self.udaf, args) } + } fn create_groups_accumulator( &self, - _args: AccumulatorArgs, + args: AccumulatorArgs, ) -> Result> { + let args = FFI_AccumulatorArgs::try_from(args)?; + + unsafe { df_result!((self.udaf.accumulator)(&self.udaf, args)) } } fn aliases(&self) -> &[String] { From 53fafc9096ffdcecef1433f1ffac8b70c3e90089 Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Wed, 19 Feb 2025 08:59:37 -0500 Subject: [PATCH 03/32] MVP for aggregate udf via FFI --- datafusion/ffi/src/udaf/accumulator.rs | 26 +- datafusion/ffi/src/udaf/accumulator_args.rs | 142 ++++++--- datafusion/ffi/src/udaf/groups_accumulator.rs | 36 ++- datafusion/ffi/src/udaf/mod.rs | 283 ++++++++++++------ 4 files changed, 333 insertions(+), 154 deletions(-) diff --git a/datafusion/ffi/src/udaf/accumulator.rs b/datafusion/ffi/src/udaf/accumulator.rs index 4b709ff1a3c3..6ebe67a4b1c5 100644 --- a/datafusion/ffi/src/udaf/accumulator.rs +++ 
b/datafusion/ffi/src/udaf/accumulator.rs @@ -1,7 +1,21 @@ -use std::{ - ffi::c_void, - sync::{Arc, Mutex}, -}; +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::ffi::c_void; use abi_stable::{ std_types::{RResult, RString, RVec}, @@ -46,10 +60,6 @@ pub struct FFI_Accumulator { pub supports_retract_batch: unsafe extern "C" fn(accumulator: &Self) -> bool, - /// Used to create a clone on the provider of the accumulator. This should - /// only need to be called by the receiver of the accumulator. - // pub clone: unsafe extern "C" fn(accumulator: &Self) -> Self, - /// Release the memory of the private data when it is no longer being used. pub release: unsafe extern "C" fn(accumulator: &mut Self), diff --git a/datafusion/ffi/src/udaf/accumulator_args.rs b/datafusion/ffi/src/udaf/accumulator_args.rs index 2eeccca56c96..f15bed5aa2c2 100644 --- a/datafusion/ffi/src/udaf/accumulator_args.rs +++ b/datafusion/ffi/src/udaf/accumulator_args.rs @@ -1,14 +1,33 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + use std::sync::Arc; use abi_stable::{ std_types::{RString, RVec}, StableAbi, }; -use arrow::{datatypes::Schema, ffi::FFI_ArrowSchema}; +use arrow::{ + datatypes::{DataType, Schema}, + ffi::FFI_ArrowSchema, +}; use datafusion::{ - error::{DataFusionError, Result}, - logical_expr::function::AccumulatorArgs, - prelude::SessionContext, + error::DataFusionError, logical_expr::function::AccumulatorArgs, + physical_expr::LexOrdering, physical_plan::PhysicalExpr, prelude::SessionContext, }; use datafusion_proto::{ physical_plan::{ @@ -20,7 +39,7 @@ use datafusion_proto::{ }; use prost::Message; -use crate::{arrow_wrappers::WrappedSchema, rresult_return}; +use crate::arrow_wrappers::WrappedSchema; #[repr(C)] #[derive(Debug, StableAbi)] @@ -33,51 +52,10 @@ pub struct FFI_AccumulatorArgs { physical_expr_def: RVec, } -impl FFI_AccumulatorArgs { - pub fn to_accumulator_args(&self) -> Result { - let proto_def = - PhysicalAggregateExprNode::decode(self.physical_expr_def.as_ref()) - .map_err(|e| DataFusionError::Execution(e.to_string()))?; - - let return_type = &(&self.return_type.0).try_into()?; - let schema = &Arc::new(Schema::try_from(&self.schema.0)?); - - let default_ctx = SessionContext::new(); - let codex = DefaultPhysicalExtensionCodec {}; - - // let proto_ordering_req = - // rresult_return!(PhysicalSortExprNodeCollection::decode(ordering_req.as_ref())); - let ordering_req = 
&parse_physical_sort_exprs( - &proto_def.ordering_req, - &default_ctx, - &schema, - &codex, - )?; - - let exprs = &rresult_return!(parse_physical_exprs( - &proto_def.expr, - &default_ctx, - &schema, - &codex - )); - - Ok(AccumulatorArgs { - return_type, - schema, - ignore_nulls: proto_def.ignore_nulls, - ordering_req, - is_reversed: self.is_reversed, - name: self.name.as_str(), - is_distinct: proto_def.distinct, - exprs, - }) - } -} - -impl<'a> TryFrom> for FFI_AccumulatorArgs { +impl TryFrom> for FFI_AccumulatorArgs { type Error = DataFusionError; - fn try_from(args: AccumulatorArgs) -> std::result::Result { + fn try_from(args: AccumulatorArgs) -> Result { let return_type = WrappedSchema(FFI_ArrowSchema::try_from(args.return_type)?); let schema = WrappedSchema(FFI_ArrowSchema::try_from(args.schema)?); @@ -106,3 +84,71 @@ impl<'a> TryFrom> for FFI_AccumulatorArgs { }) } } + +/// This struct mirrors AccumulatorArgs except that it contains owned data. +/// It is necessary to create this struct so that we can parse the protobuf +/// data across the FFI boundary and turn it into owned data that +/// AccumulatorArgs can then reference. 
+pub struct ForeignAccumulatorArgs { + pub return_type: DataType, + pub schema: Schema, + pub ignore_nulls: bool, + pub ordering_req: LexOrdering, + pub is_reversed: bool, + pub name: String, + pub is_distinct: bool, + pub exprs: Vec>, +} + +impl TryFrom for ForeignAccumulatorArgs { + type Error = DataFusionError; + + fn try_from(value: FFI_AccumulatorArgs) -> Result { + let proto_def = + PhysicalAggregateExprNode::decode(value.physical_expr_def.as_ref()) + .map_err(|e| DataFusionError::Execution(e.to_string()))?; + + let return_type = (&value.return_type.0).try_into()?; + let schema = Schema::try_from(&value.schema.0)?; + + let default_ctx = SessionContext::new(); + let codex = DefaultPhysicalExtensionCodec {}; + + // let proto_ordering_req = + // rresult_return!(PhysicalSortExprNodeCollection::decode(ordering_req.as_ref())); + let ordering_req = parse_physical_sort_exprs( + &proto_def.ordering_req, + &default_ctx, + &schema, + &codex, + )?; + + let exprs = parse_physical_exprs(&proto_def.expr, &default_ctx, &schema, &codex)?; + + Ok(Self { + return_type, + schema, + ignore_nulls: proto_def.ignore_nulls, + ordering_req, + is_reversed: value.is_reversed, + name: value.name.to_string(), + is_distinct: proto_def.distinct, + exprs, + }) + } +} + +impl<'a> From<&'a ForeignAccumulatorArgs> for AccumulatorArgs<'a> { + fn from(value: &'a ForeignAccumulatorArgs) -> Self { + Self { + return_type: &value.return_type, + schema: &value.schema, + ignore_nulls: value.ignore_nulls, + ordering_req: &value.ordering_req, + is_reversed: value.is_reversed, + name: value.name.as_str(), + is_distinct: value.is_distinct, + exprs: &value.exprs, + } + } +} diff --git a/datafusion/ffi/src/udaf/groups_accumulator.rs b/datafusion/ffi/src/udaf/groups_accumulator.rs index 0c5eb9475d78..eaf4b991477a 100644 --- a/datafusion/ffi/src/udaf/groups_accumulator.rs +++ b/datafusion/ffi/src/udaf/groups_accumulator.rs @@ -1,7 +1,21 @@ -use std::{ - ffi::c_void, - sync::{Arc, Mutex}, -}; +// Licensed to 
the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::ffi::c_void; use abi_stable::{ std_types::{ROption, RResult, RString, RVec}, @@ -10,14 +24,12 @@ use abi_stable::{ use arrow::{ array::{Array, ArrayRef, BooleanArray}, error::ArrowError, - ffi::{from_ffi, to_ffi, FFI_ArrowArray}, + ffi::to_ffi, }; use datafusion::{ error::{DataFusionError, Result}, logical_expr::{EmitTo, GroupsAccumulator}, - scalar::ScalarValue, }; -use prost::Message; use crate::{ arrow_wrappers::{WrappedArray, WrappedSchema}, @@ -66,10 +78,6 @@ pub struct FFI_GroupsAccumulator { pub supports_convert_to_state: bool, - /// Used to create a clone on the provider of the accumulator. This should - /// only need to be called by the receiver of the accumulator. - // pub clone: unsafe extern "C" fn(accumulator: &Self) -> Self, - /// Release the memory of the private data when it is no longer being used. 
pub release: unsafe extern "C" fn(accumulator: &mut Self), @@ -112,7 +120,7 @@ unsafe extern "C" fn update_batch_fn_wrapper( } } }).map(|arr| arr.into_data()); - let opt_filter = maybe_filter.map(|arr| BooleanArray::from(arr)); + let opt_filter = maybe_filter.map(BooleanArray::from); rresult!(accum_data.accumulator.update_batch( &values_arrays, @@ -181,7 +189,7 @@ unsafe extern "C" fn merge_batch_fn_wrapper( } } }).map(|arr| arr.into_data()); - let opt_filter = maybe_filter.map(|arr| BooleanArray::from(arr)); + let opt_filter = maybe_filter.map(BooleanArray::from); rresult!(accum_data.accumulator.merge_batch( &values_arrays, @@ -212,7 +220,7 @@ unsafe extern "C" fn convert_to_state_fn_wrapper( None } } - }).map(|arr| arr.into_data()).map(|arr| BooleanArray::from(arr)); + }).map(|arr| arr.into_data()).map(BooleanArray::from); let state = rresult_return!(accum_data .accumulator diff --git a/datafusion/ffi/src/udaf/mod.rs b/datafusion/ffi/src/udaf/mod.rs index 75b200b5d147..d5def52cfc31 100644 --- a/datafusion/ffi/src/udaf/mod.rs +++ b/datafusion/ffi/src/udaf/mod.rs @@ -15,56 +15,39 @@ // specific language governing permissions and limitations // under the License. 
-use std::{ - ffi::c_void, - sync::{Arc, Mutex}, -}; +use std::{ffi::c_void, sync::Arc}; use abi_stable::{ - std_types::{RResult, RStr, RString, RVec}, + std_types::{ROption, RResult, RStr, RString, RVec}, StableAbi, }; -use accumulator::FFI_Accumulator; -use accumulator_args::FFI_AccumulatorArgs; -use arrow::datatypes::{DataType, Field, SchemaRef}; -use arrow::ffi::{from_ffi, to_ffi, FFI_ArrowSchema}; +use accumulator::{FFI_Accumulator, ForeignAccumulator}; +use accumulator_args::{FFI_AccumulatorArgs, ForeignAccumulatorArgs}; +use arrow::datatypes::{DataType, Field}; +use arrow::ffi::FFI_ArrowSchema; use datafusion::{ error::DataFusionError, logical_expr::{ function::{AccumulatorArgs, AggregateFunctionSimplification, StateFieldsArgs}, utils::AggregateOrderSensitivity, - Accumulator, GroupsAccumulator, ReversedUDAF, + Accumulator, GroupsAccumulator, }, - physical_plan::aggregates::order, - prelude::SessionContext, }; use datafusion::{ error::Result, - logical_expr::{ - AggregateUDF, AggregateUDFImpl, ColumnarValue, ScalarFunctionArgs, Signature, - }, -}; -use datafusion_proto::{ - physical_plan::{ - from_proto::{parse_physical_exprs, parse_physical_sort_exprs}, - to_proto::{ - serialize_physical_expr, serialize_physical_exprs, - serialize_physical_sort_exprs, - }, - DefaultPhysicalExtensionCodec, - }, - protobuf::{PhysicalAggregateExprNode, PhysicalSortExprNodeCollection}, + logical_expr::{AggregateUDF, AggregateUDFImpl, Signature}, }; -use groups_accumulator::FFI_GroupsAccumulator; +use datafusion_proto_common::from_proto::parse_proto_fields_to_fields; +use groups_accumulator::{FFI_GroupsAccumulator, ForeignGroupsAccumulator}; use crate::{ - arrow_wrappers::{WrappedArray, WrappedSchema}, + arrow_wrappers::WrappedSchema, df_result, rresult, rresult_return, signature::{ - self, rvec_wrapped_to_vec_datatype, vec_datatype_to_rvec_wrapped, FFI_Signature, + rvec_wrapped_to_vec_datatype, vec_datatype_to_rvec_wrapped, FFI_Signature, }, }; -use prost::Message; +use 
prost::{DecodeError, Message}; mod accumulator; mod accumulator_args; @@ -97,6 +80,13 @@ pub struct FFI_AggregateUDF { args: FFI_AccumulatorArgs, ) -> RResult, + pub create_sliding_accumulator: + unsafe extern "C" fn( + udaf: &FFI_AggregateUDF, + args: FFI_AccumulatorArgs, + ) -> RResult, + + #[allow(clippy::type_complexity)] pub state_fields: unsafe extern "C" fn( udaf: &FFI_AggregateUDF, name: &RStr, @@ -108,9 +98,18 @@ pub struct FFI_AggregateUDF { pub create_groups_accumulator: unsafe extern "C" fn( - &FFI_AggregateUDF, + udaf: &FFI_AggregateUDF, args: FFI_AccumulatorArgs, - ) -> RResult, + ) -> RResult, + + pub with_beneficial_ordering: + unsafe extern "C" fn( + udaf: &FFI_AggregateUDF, + beneficial_ordering: bool, + ) -> RResult, RString>, + + pub order_sensitivity: + unsafe extern "C" fn(udaf: &FFI_AggregateUDF) -> FFI_AggregateOrderSensitivity, /// Used to create a clone on the provider of the udaf. This should /// only need to be called by the receiver of the udaf. @@ -131,25 +130,23 @@ pub struct AggregateUDFPrivateData { pub udaf: Arc, } -unsafe extern "C" fn name_fn_wrapper(udaf: &FFI_AggregateUDF) -> RString { - let private_data = udaf.private_data as *const AggregateUDFPrivateData; - let udaf = &(*private_data).udaf; - - udaf.name().into() +impl FFI_AggregateUDF { + unsafe fn inner(&self) -> &Arc { + let private_data = self.private_data as *const AggregateUDFPrivateData; + &(*private_data).udaf + } } unsafe extern "C" fn signature_fn_wrapper( udaf: &FFI_AggregateUDF, ) -> RResult { - let private_data = udaf.private_data as *const AggregateUDFPrivateData; - let udaf = &(*private_data).udaf; + let udaf = udaf.inner(); rresult!(udaf.signature().try_into()) } unsafe extern "C" fn aliases_fn_wrapper(udaf: &FFI_AggregateUDF) -> RVec { - let private_data = udaf.private_data as *const AggregateUDFPrivateData; - let udaf = &(*private_data).udaf; + let udaf = udaf.inner(); udaf.aliases().iter().map(|s| s.to_owned().into()).collect() } @@ -158,8 +155,7 @@ unsafe 
extern "C" fn return_type_fn_wrapper( udaf: &FFI_AggregateUDF, arg_types: RVec, ) -> RResult { - let private_data = udaf.private_data as *const AggregateUDFPrivateData; - let udaf = &(*private_data).udaf; + let udaf = udaf.inner(); let arg_types = rresult_return!(rvec_wrapped_to_vec_datatype(&arg_types)); @@ -175,27 +171,38 @@ unsafe extern "C" fn accumulator_fn_wrapper( udaf: &FFI_AggregateUDF, args: FFI_AccumulatorArgs, ) -> RResult { - let private_data = udaf.private_data as *const AggregateUDFPrivateData; - let udaf = &(*private_data).udaf; + let udaf = udaf.inner(); - let accumulator_args = rresult_return!(args.to_accumulator_args()); + let accumulator_args = &rresult_return!(ForeignAccumulatorArgs::try_from(args)); rresult!(udaf - .accumulator(accumulator_args) + .accumulator(accumulator_args.into()) .map(FFI_Accumulator::from)) } -unsafe extern "C" fn create_groups_accumulator_fn_wrapper( +unsafe extern "C" fn create_sliding_accumulator_fn_wrapper( udaf: &FFI_AggregateUDF, args: FFI_AccumulatorArgs, ) -> RResult { - let private_data = udaf.private_data as *const AggregateUDFPrivateData; - let udaf = &(*private_data).udaf; + let udaf = udaf.inner(); - let accumulator_args = rresult_return!(args.to_accumulator_args()); + let accumulator_args = &rresult_return!(ForeignAccumulatorArgs::try_from(args)); rresult!(udaf - .create_groups_accumulator(accumulator_args) + .create_sliding_accumulator(accumulator_args.into()) + .map(FFI_Accumulator::from)) +} + +unsafe extern "C" fn create_groups_accumulator_fn_wrapper( + udaf: &FFI_AggregateUDF, + args: FFI_AccumulatorArgs, +) -> RResult { + let udaf = udaf.inner(); + + let accumulator_args = &rresult_return!(ForeignAccumulatorArgs::try_from(args)); + + rresult!(udaf + .create_groups_accumulator(accumulator_args.into()) .map(FFI_GroupsAccumulator::from)) } @@ -203,27 +210,86 @@ unsafe extern "C" fn groups_accumulator_supported_fn_wrapper( udaf: &FFI_AggregateUDF, args: FFI_AccumulatorArgs, ) -> bool { - let private_data 
= udaf.private_data as *const AggregateUDFPrivateData; - let udaf = &(*private_data).udaf; + let udaf = udaf.inner(); - args.to_accumulator_args() - .map(|a| udaf.groups_accumulator_supported(a)) + ForeignAccumulatorArgs::try_from(args) + .map(|a| udaf.groups_accumulator_supported((&a).into())) .unwrap_or_else(|e| { log::warn!("Unable to parse accumulator args. {}", e); false }) } +unsafe extern "C" fn with_beneficial_ordering_fn_wrapper( + udaf: &FFI_AggregateUDF, + beneficial_ordering: bool, +) -> RResult, RString> { + let udaf = udaf.inner().as_ref().clone(); + + let result = rresult_return!(udaf.with_beneficial_ordering(beneficial_ordering)); + let result = rresult_return!(result + .map(|func| func.with_beneficial_ordering(beneficial_ordering)) + .transpose()) + .flatten() + .map(|func| FFI_AggregateUDF::from(Arc::new(func))); + + RResult::ROk(result.into()) +} + +unsafe extern "C" fn state_fields_fn_wrapper( + udaf: &FFI_AggregateUDF, + name: &RStr, + input_types: RVec, + return_type: WrappedSchema, + ordering_fields: RVec>, + is_distinct: bool, +) -> RResult>, RString> { + let udaf = udaf.inner(); + + let input_types = &rresult_return!(rvec_wrapped_to_vec_datatype(&input_types)); + let return_type = &rresult_return!(DataType::try_from(&return_type.0)); + + let ordering_fields = &rresult_return!(ordering_fields + .into_iter() + .map(|field_bytes| datafusion_proto_common::Field::decode(field_bytes.as_ref())) + .collect::, DecodeError>>()); + + let ordering_fields = &rresult_return!(parse_proto_fields_to_fields(ordering_fields)); + + let args = StateFieldsArgs { + name: name.as_str(), + input_types, + return_type, + ordering_fields, + is_distinct, + }; + + let state_fields = rresult_return!(udaf.state_fields(args)); + let state_fields = rresult_return!(state_fields + .iter() + .map(datafusion_proto::protobuf::Field::try_from) + .map(|v| v.map_err(DataFusionError::from)) + .collect::>>()) + .into_iter() + .map(|field| field.encode_to_vec().into()) + .collect(); + 
+ RResult::ROk(state_fields) +} + +unsafe extern "C" fn order_sensitivity_fn_wrapper( + udaf: &FFI_AggregateUDF, +) -> FFI_AggregateOrderSensitivity { + udaf.inner().order_sensitivity().into() +} + unsafe extern "C" fn release_fn_wrapper(udaf: &mut FFI_AggregateUDF) { let private_data = Box::from_raw(udaf.private_data as *mut AggregateUDFPrivateData); drop(private_data); } unsafe extern "C" fn clone_fn_wrapper(udaf: &FFI_AggregateUDF) -> FFI_AggregateUDF { - let private_data = udaf.private_data as *const AggregateUDFPrivateData; - let udaf_data = &(*private_data); - - Arc::clone(&udaf_data.udaf).into() + Arc::clone(udaf.inner()).into() } impl Clone for FFI_AggregateUDF { @@ -246,8 +312,12 @@ impl From> for FFI_AggregateUDF { aliases: aliases_fn_wrapper, return_type: return_type_fn_wrapper, accumulator: accumulator_fn_wrapper, + create_sliding_accumulator: create_sliding_accumulator_fn_wrapper, create_groups_accumulator: create_groups_accumulator_fn_wrapper, groups_accumulator_supported: groups_accumulator_supported_fn_wrapper, + with_beneficial_ordering: with_beneficial_ordering_fn_wrapper, + state_fields: state_fields_fn_wrapper, + order_sensitivity: order_sensitivity_fn_wrapper, clone: clone_fn_wrapper, release: release_fn_wrapper, private_data: Box::into_raw(private_data) as *mut c_void, @@ -328,8 +398,11 @@ impl AggregateUDFImpl for ForeignAggregateUDF { fn accumulator(&self, acc_args: AccumulatorArgs) -> Result> { let args = acc_args.try_into()?; - - unsafe { df_result!((self.udaf.accumulator)(&self.udaf, args)) } + unsafe { + df_result!((self.udaf.accumulator)(&self.udaf, args)).map(|accum| { + Box::new(ForeignAccumulator::from(accum)) as Box + }) + } } fn state_fields(&self, args: StateFieldsArgs) -> Result> { @@ -363,10 +436,8 @@ impl AggregateUDFImpl for ForeignAggregateUDF { }) .collect::>>()?; - datafusion_proto_common::from_proto::parse_proto_fields_to_fields( - fields.iter(), - ) - .map_err(|e| DataFusionError::Execution(e.to_string())) + 
parse_proto_fields_to_fields(fields.iter()) + .map_err(|e| DataFusionError::Execution(e.to_string())) } } @@ -388,7 +459,14 @@ impl AggregateUDFImpl for ForeignAggregateUDF { ) -> Result> { let args = FFI_AccumulatorArgs::try_from(args)?; - unsafe { df_result!((self.udaf.accumulator)(&self.udaf, args)) } + unsafe { + df_result!((self.udaf.create_groups_accumulator)(&self.udaf, args)).map( + |accum| { + Box::new(ForeignGroupsAccumulator::from(accum)) + as Box + }, + ) + } } fn aliases(&self) -> &[String] { @@ -399,32 +477,40 @@ impl AggregateUDFImpl for ForeignAggregateUDF { &self, args: AccumulatorArgs, ) -> Result> { + let args = args.try_into()?; + unsafe { + df_result!((self.udaf.create_sliding_accumulator)(&self.udaf, args)).map( + |accum| Box::new(ForeignAccumulator::from(accum)) as Box, + ) + } } fn with_beneficial_ordering( self: Arc, - _beneficial_ordering: bool, + beneficial_ordering: bool, ) -> Result>> { - } - - fn order_sensitivity(&self) -> AggregateOrderSensitivity {} - - fn simplify(&self) -> Option {} - - fn reverse_expr(&self) -> ReversedUDAF {} - - fn coerce_types(&self, _arg_types: &[DataType]) -> Result> {} - - fn equals(&self, other: &dyn AggregateUDFImpl) -> bool {} + unsafe { + let result = df_result!((self.udaf.with_beneficial_ordering)( + &self.udaf, + beneficial_ordering + ))? 
+ .into_option(); - fn is_descending(&self) -> Option {} + let result = result + .map(|func| ForeignAggregateUDF::try_from(&func)) + .transpose()?; - fn value_from_stats(&self, _statistics_args: &StatisticsArgs) -> Option { + Ok(result.map(|func| Arc::new(func) as Arc)) + } } - fn default_value(&self, data_type: &DataType) -> Result {} + fn order_sensitivity(&self) -> AggregateOrderSensitivity { + unsafe { (self.udaf.order_sensitivity)(&self.udaf).into() } + } - fn set_monotonicity(&self, _data_type: &DataType) -> SetMonotonicity {} + fn simplify(&self) -> Option { + None + } } #[cfg(test)] @@ -433,7 +519,7 @@ mod tests { #[test] fn test_round_trip_udaf() -> Result<()> { - let original_udaf = datafusion::functions::math::abs::AbsFunc::new(); + let original_udaf = datafusion::functions_aggregate::sum::Sum::new(); let original_udaf = Arc::new(AggregateUDF::from(original_udaf)); let local_udaf: FFI_AggregateUDF = Arc::clone(&original_udaf).into(); @@ -445,3 +531,32 @@ mod tests { Ok(()) } } + +#[repr(C)] +#[derive(Debug, StableAbi)] +#[allow(non_camel_case_types)] +pub enum FFI_AggregateOrderSensitivity { + Insensitive, + HardRequirement, + Beneficial, +} + +impl From for AggregateOrderSensitivity { + fn from(value: FFI_AggregateOrderSensitivity) -> Self { + match value { + FFI_AggregateOrderSensitivity::Insensitive => Self::Insensitive, + FFI_AggregateOrderSensitivity::HardRequirement => Self::HardRequirement, + FFI_AggregateOrderSensitivity::Beneficial => Self::Beneficial, + } + } +} + +impl From for FFI_AggregateOrderSensitivity { + fn from(value: AggregateOrderSensitivity) -> Self { + match value { + AggregateOrderSensitivity::Insensitive => Self::Insensitive, + AggregateOrderSensitivity::HardRequirement => Self::HardRequirement, + AggregateOrderSensitivity::Beneficial => Self::Beneficial, + } + } +} From 95c3c799d3a5209c5476a2898ef9fc3626f70b90 Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Thu, 20 Feb 2025 17:57:54 -0500 Subject: [PATCH 04/32] Clean up after 
rebase --- datafusion/ffi/src/udaf/mod.rs | 54 ++++++++++++---------------------- 1 file changed, 18 insertions(+), 36 deletions(-) diff --git a/datafusion/ffi/src/udaf/mod.rs b/datafusion/ffi/src/udaf/mod.rs index d5def52cfc31..4107f7bd0ff4 100644 --- a/datafusion/ffi/src/udaf/mod.rs +++ b/datafusion/ffi/src/udaf/mod.rs @@ -43,9 +43,8 @@ use groups_accumulator::{FFI_GroupsAccumulator, ForeignGroupsAccumulator}; use crate::{ arrow_wrappers::WrappedSchema, df_result, rresult, rresult_return, - signature::{ - rvec_wrapped_to_vec_datatype, vec_datatype_to_rvec_wrapped, FFI_Signature, - }, + util::{rvec_wrapped_to_vec_datatype, vec_datatype_to_rvec_wrapped}, + volatility::FFI_Volatility, }; use prost::{DecodeError, Message}; @@ -58,12 +57,14 @@ mod groups_accumulator; #[derive(Debug, StableAbi)] #[allow(non_camel_case_types)] pub struct FFI_AggregateUDF { - /// Return the udaf name. + /// FFI equivalent to the `name` of a [`AggregateUDF`] pub name: RString, - pub signature: unsafe extern "C" fn(udaf: &Self) -> RResult, + /// FFI equivalent to the `aliases` of a [`AggregateUDF`] + pub aliases: RVec, - pub aliases: unsafe extern "C" fn(udaf: &Self) -> RVec, + /// FFI equivalent to the `volatility` of a [`AggregateUDF`] + pub volatility: FFI_Volatility, pub return_type: unsafe extern "C" fn( udaf: &Self, @@ -137,20 +138,6 @@ impl FFI_AggregateUDF { } } -unsafe extern "C" fn signature_fn_wrapper( - udaf: &FFI_AggregateUDF, -) -> RResult { - let udaf = udaf.inner(); - - rresult!(udaf.signature().try_into()) -} - -unsafe extern "C" fn aliases_fn_wrapper(udaf: &FFI_AggregateUDF) -> RVec { - let udaf = udaf.inner(); - - udaf.aliases().iter().map(|s| s.to_owned().into()).collect() -} - unsafe extern "C" fn return_type_fn_wrapper( udaf: &FFI_AggregateUDF, arg_types: RVec, @@ -301,15 +288,17 @@ impl Clone for FFI_AggregateUDF { impl From> for FFI_AggregateUDF { fn from(udaf: Arc) -> Self { let name = udaf.name().into(); + let aliases = udaf.aliases().iter().map(|a| 
a.to_owned().into()).collect(); let is_nullable = udaf.is_nullable(); + let volatility = udaf.signature().volatility.into(); let private_data = Box::new(AggregateUDFPrivateData { udaf }); Self { name, is_nullable, - signature: signature_fn_wrapper, - aliases: aliases_fn_wrapper, + volatility, + aliases, return_type: return_type_fn_wrapper, accumulator: accumulator_fn_wrapper, create_sliding_accumulator: create_sliding_accumulator_fn_wrapper, @@ -351,21 +340,14 @@ impl TryFrom<&FFI_AggregateUDF> for ForeignAggregateUDF { type Error = DataFusionError; fn try_from(udaf: &FFI_AggregateUDF) -> Result { - unsafe { - let ffi_signature = df_result!((udaf.signature)(udaf))?; - let signature = (&ffi_signature).try_into()?; - - let aliases = (udaf.aliases)(udaf) - .into_iter() - .map(|s| s.to_string()) - .collect(); + let signature = Signature::user_defined((&udaf.volatility).into()); + let aliases = udaf.aliases.iter().map(|s| s.to_string()).collect(); - Ok(Self { - udaf: udaf.clone(), - signature, - aliases, - }) - } + Ok(Self { + udaf: udaf.clone(), + signature, + aliases, + }) } } From a91ee5b6433e5b88d058eb5b04f54bec980a2d37 Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Fri, 21 Feb 2025 19:57:24 -0500 Subject: [PATCH 05/32] Add unit test for FFI Accumulator Args --- datafusion/ffi/src/udaf/accumulator_args.rs | 37 +++++++++++++++++++++ 1 file changed, 37 insertions(+) diff --git a/datafusion/ffi/src/udaf/accumulator_args.rs b/datafusion/ffi/src/udaf/accumulator_args.rs index f15bed5aa2c2..b3933c267025 100644 --- a/datafusion/ffi/src/udaf/accumulator_args.rs +++ b/datafusion/ffi/src/udaf/accumulator_args.rs @@ -152,3 +152,40 @@ impl<'a> From<&'a ForeignAccumulatorArgs> for AccumulatorArgs<'a> { } } } + +#[cfg(test)] +mod tests { + use super::{FFI_AccumulatorArgs, ForeignAccumulatorArgs}; + use arrow::datatypes::{DataType, Schema}; + use datafusion::{ + error::Result, logical_expr::function::AccumulatorArgs, + physical_expr::LexOrdering, + }; + + #[test] + fn 
test_round_trip_accumulator_args() -> Result<()> { + let orig_args = AccumulatorArgs { + return_type: &DataType::Float64, + schema: &Schema::empty(), + ignore_nulls: false, + ordering_req: &LexOrdering::new(vec![]), + is_reversed: false, + name: "round_trip", + is_distinct: true, + exprs: &[], + }; + let orig_str = format!("{:?}", orig_args); + + let ffi_args: FFI_AccumulatorArgs = orig_args.try_into()?; + let foreign_args: ForeignAccumulatorArgs = ffi_args.try_into()?; + let round_trip_args: AccumulatorArgs = (&foreign_args).into(); + + let round_trip_str = format!("{:?}", round_trip_args); + + // Since AccumulatorArgs doesn't implement Eq, simply compare + // the debug strings. + assert_eq!(orig_str, round_trip_str); + + Ok(()) + } +} From d15b3a1a8581ce0271af4bfa944b1ab836aea275 Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Sun, 23 Feb 2025 03:54:37 -0500 Subject: [PATCH 06/32] Adding unit tests and fixing memory errors in aggregate ffi udf --- .../expr-common/src/groups_accumulator.rs | 2 +- datafusion/ffi/src/tests/mod.rs | 8 +- datafusion/ffi/src/tests/udf_udaf_udwf.rs | 11 ++- datafusion/ffi/src/udaf/accumulator.rs | 83 +++++++++++++------ datafusion/ffi/src/udaf/groups_accumulator.rs | 77 ++++++++++++++++- datafusion/ffi/src/udaf/mod.rs | 33 ++++++++ datafusion/ffi/tests/ffi_integration.rs | 48 ++++++++++- 7 files changed, 229 insertions(+), 33 deletions(-) diff --git a/datafusion/expr-common/src/groups_accumulator.rs b/datafusion/expr-common/src/groups_accumulator.rs index 5ff1c1d07216..9bcc1edff882 100644 --- a/datafusion/expr-common/src/groups_accumulator.rs +++ b/datafusion/expr-common/src/groups_accumulator.rs @@ -21,7 +21,7 @@ use arrow::array::{ArrayRef, BooleanArray}; use datafusion_common::{not_impl_err, Result}; /// Describes how many rows should be emitted during grouping. 
-#[derive(Debug, Clone, Copy)] +#[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum EmitTo { /// Emit all groups All, diff --git a/datafusion/ffi/src/tests/mod.rs b/datafusion/ffi/src/tests/mod.rs index 7a36ee52bdb4..e4b162644d5e 100644 --- a/datafusion/ffi/src/tests/mod.rs +++ b/datafusion/ffi/src/tests/mod.rs @@ -29,6 +29,8 @@ use catalog::create_catalog_provider; use crate::{catalog_provider::FFI_CatalogProvider, udtf::FFI_TableFunction}; +use crate::udaf::FFI_AggregateUDF; + use super::{table_provider::FFI_TableProvider, udf::FFI_ScalarUDF}; use arrow::array::RecordBatch; use async_provider::create_async_table_provider; @@ -37,7 +39,7 @@ use datafusion::{ common::record_batch, }; use sync_provider::create_sync_table_provider; -use udf_udaf_udwf::{create_ffi_abs_func, create_ffi_random_func, create_ffi_table_func}; +use udf_udaf_udwf::{create_ffi_abs_func, create_ffi_avg_func, create_ffi_random_func, create_ffi_table_func}; mod async_provider; pub mod catalog; @@ -65,6 +67,9 @@ pub struct ForeignLibraryModule { pub create_table_function: extern "C" fn() -> FFI_TableFunction, + /// Create an aggregate UDF + pub create_udaf: extern "C" fn() -> FFI_AggregateUDF, + pub version: extern "C" fn() -> u64, } @@ -112,6 +117,7 @@ pub fn get_foreign_library_module() -> ForeignLibraryModuleRef { create_scalar_udf: create_ffi_abs_func, create_nullary_udf: create_ffi_random_func, create_table_function: create_ffi_table_func, + create_udaf: create_ffi_avg_func, version: super::version, } .leak_into_prefix() diff --git a/datafusion/ffi/src/tests/udf_udaf_udwf.rs b/datafusion/ffi/src/tests/udf_udaf_udwf.rs index c3cb1bcc3533..3edc383f4ca3 100644 --- a/datafusion/ffi/src/tests/udf_udaf_udwf.rs +++ b/datafusion/ffi/src/tests/udf_udaf_udwf.rs @@ -15,12 +15,13 @@ // specific language governing permissions and limitations // under the License. 
-use crate::{udf::FFI_ScalarUDF, udtf::FFI_TableFunction}; +use crate::{udaf::FFI_AggregateUDF, udf::FFI_ScalarUDF, udtf::FFI_TableFunction}; use datafusion::{ catalog::TableFunctionImpl, functions::math::{abs::AbsFunc, random::RandomFunc}, + functions_aggregate::sum::Sum, functions_table::generate_series::RangeFunc, - logical_expr::ScalarUDF, + logical_expr::{AggregateUDF, ScalarUDF}, }; use std::sync::Arc; @@ -42,3 +43,9 @@ pub(crate) extern "C" fn create_ffi_table_func() -> FFI_TableFunction { FFI_TableFunction::new(udtf, None) } + +pub(crate) extern "C" fn create_ffi_avg_func() -> FFI_AggregateUDF { + let udaf: Arc = Arc::new(Sum::new().into()); + + udaf.into() +} diff --git a/datafusion/ffi/src/udaf/accumulator.rs b/datafusion/ffi/src/udaf/accumulator.rs index 6ebe67a4b1c5..c4f8edfeafc5 100644 --- a/datafusion/ffi/src/udaf/accumulator.rs +++ b/datafusion/ffi/src/udaf/accumulator.rs @@ -58,7 +58,7 @@ pub struct FFI_Accumulator { values: RVec, ) -> RResult<(), RString>, - pub supports_retract_batch: unsafe extern "C" fn(accumulator: &Self) -> bool, + pub supports_retract_batch: bool, /// Release the memory of the private data when it is no longer being used. 
pub release: unsafe extern "C" fn(accumulator: &mut Self), @@ -161,35 +161,17 @@ unsafe extern "C" fn retract_batch_fn_wrapper( rresult!(accum_data.accumulator.retract_batch(&values)) } -unsafe extern "C" fn supports_retract_batch_fn_wrapper( - accumulator: &FFI_Accumulator, -) -> bool { - let private_data = accumulator.private_data as *mut AccumulatorPrivateData; - let accum_data = &mut (*private_data); - accum_data.accumulator.supports_retract_batch() -} - unsafe extern "C" fn release_fn_wrapper(accumulator: &mut FFI_Accumulator) { let private_data = Box::from_raw(accumulator.private_data as *mut AccumulatorPrivateData); drop(private_data); } -// unsafe extern "C" fn clone_fn_wrapper(accumulator: &FFI_Accumulator) -> FFI_Accumulator { -// let private_data = accumulator.private_data as *const AccumulatorPrivateData; -// let accum_data = &(*private_data); - -// Box::new(accum_data.accumulator).into() -// } - -// impl Clone for FFI_Accumulator { -// fn clone(&self) -> Self { -// unsafe { (self.clone)(self) } -// } -// } - impl From> for FFI_Accumulator { fn from(accumulator: Box) -> Self { + let supports_retract_batch = accumulator.supports_retract_batch(); + let private_data = AccumulatorPrivateData { accumulator }; + Self { update_batch: update_batch_fn_wrapper, evaluate: evaluate_fn_wrapper, @@ -197,11 +179,11 @@ impl From> for FFI_Accumulator { state: state_fn_wrapper, merge_batch: merge_batch_fn_wrapper, retract_batch: retract_batch_fn_wrapper, - supports_retract_batch: supports_retract_batch_fn_wrapper, + supports_retract_batch, // clone: clone_fn_wrapper, release: release_fn_wrapper, - private_data: Box::into_raw(accumulator) as *mut c_void, + private_data: Box::into_raw(Box::new(private_data)) as *mut c_void, } } } @@ -308,6 +290,57 @@ impl Accumulator for ForeignAccumulator { } fn supports_retract_batch(&self) -> bool { - unsafe { (self.accumulator.supports_retract_batch)(&self.accumulator) } + self.accumulator.supports_retract_batch + } +} + +#[cfg(test)] 
+mod tests { + use arrow::array::{make_array, Array}; + use datafusion::{ + common::create_array, error::Result, + functions_aggregate::average::AvgAccumulator, logical_expr::Accumulator, + scalar::ScalarValue, + }; + + use super::{FFI_Accumulator, ForeignAccumulator}; + + #[test] + fn test_foreign_avg_accumulator() -> Result<()> { + let boxed_accum: Box = Box::new(AvgAccumulator::default()); + let ffi_accum: FFI_Accumulator = boxed_accum.into(); + let mut foreign_accum: ForeignAccumulator = ffi_accum.into(); + + // Send in an array to average. There are 5 values and it should average to 30.0 + let values = create_array!(Float64, vec![10., 20., 30., 40., 50.]); + foreign_accum.update_batch(&[values])?; + + let avg = foreign_accum.evaluate()?; + assert_eq!(avg, ScalarValue::Float64(Some(30.0))); + + let state = foreign_accum.state()?; + assert_eq!(state.len(), 2); + assert_eq!(state[0], ScalarValue::UInt64(Some(5))); + assert_eq!(state[1], ScalarValue::Float64(Some(150.0))); + + // To verify merging batches works, create a second state to add in + // This should cause our average to go down to 25.0 + let second_states = vec![ + make_array(create_array!(UInt64, vec![1]).to_data()), + make_array(create_array!(Float64, vec![0.0]).to_data()), + ]; + + foreign_accum.merge_batch(&second_states)?; + let avg = foreign_accum.evaluate()?; + assert_eq!(avg, ScalarValue::Float64(Some(25.0))); + + // If we remove a batch that is equivalent to the state we added + // we should go back to our original value of 30.0 + let values = create_array!(Float64, vec![0.0]); + foreign_accum.retract_batch(&[values])?; + let avg = foreign_accum.evaluate()?; + assert_eq!(avg, ScalarValue::Float64(Some(30.0))); + + Ok(()) } } diff --git a/datafusion/ffi/src/udaf/groups_accumulator.rs b/datafusion/ffi/src/udaf/groups_accumulator.rs index eaf4b991477a..d5559e24b5f6 100644 --- a/datafusion/ffi/src/udaf/groups_accumulator.rs +++ b/datafusion/ffi/src/udaf/groups_accumulator.rs @@ -253,6 +253,9 @@ 
unsafe extern "C" fn release_fn_wrapper(accumulator: &mut FFI_GroupsAccumulator) impl From> for FFI_GroupsAccumulator { fn from(accumulator: Box) -> Self { + let supports_convert_to_state = accumulator.supports_convert_to_state(); + let private_data = GroupsAccumulatorPrivateData { accumulator }; + Self { update_batch: update_batch_fn_wrapper, evaluate: evaluate_fn_wrapper, @@ -260,10 +263,10 @@ impl From> for FFI_GroupsAccumulator { state: state_fn_wrapper, merge_batch: merge_batch_fn_wrapper, convert_to_state: convert_to_state_fn_wrapper, - supports_convert_to_state: accumulator.supports_convert_to_state(), + supports_convert_to_state, release: release_fn_wrapper, - private_data: Box::into_raw(accumulator) as *mut c_void, + private_data: Box::into_raw(Box::new(private_data)) as *mut c_void, } } } @@ -451,3 +454,73 @@ impl From for EmitTo { } } } + +#[cfg(test)] +mod tests { + use arrow::array::{make_array, Array, Float64Array}; + use datafusion::{ + common::create_array, + error::Result, + functions_aggregate::stddev::StddevGroupsAccumulator, + logical_expr::{EmitTo, GroupsAccumulator}, + physical_plan::expressions::StatsType, + }; + + use super::{FFI_EmitTo, FFI_GroupsAccumulator, ForeignGroupsAccumulator}; + + #[test] + fn test_foreign_avg_accumulator() -> Result<()> { + let boxed_accum: Box = + Box::new(StddevGroupsAccumulator::new(StatsType::Population)); + let ffi_accum: FFI_GroupsAccumulator = boxed_accum.into(); + let mut foreign_accum: ForeignGroupsAccumulator = ffi_accum.into(); + + // Send in an array to evaluate. We want a mean of 30 and standard deviation of 4. 
+ let values = create_array!(Float64, vec![26., 26., 34., 34.]); + foreign_accum.update_batch(&[values], &[0; 4], None, 1)?; + + let groups_avg = foreign_accum.evaluate(EmitTo::All)?; + let groups_avg = groups_avg.as_any().downcast_ref::().unwrap(); + let expected = 4.0; + assert_eq!(groups_avg.len(), 1); + assert!((groups_avg.value(0) - expected).abs() < 0.0001); + + let state = foreign_accum.state(EmitTo::All)?; + assert_eq!(state.len(), 3); + + // To verify merging batches works, create a second state to add in + // The merged result is expected to evaluate to 8.0 + let second_states = vec![ + make_array(create_array!(UInt64, vec![1]).to_data()), + make_array(create_array!(Float64, vec![30.0]).to_data()), + make_array(create_array!(Float64, vec![64.0]).to_data()), + ]; + + foreign_accum.merge_batch(&second_states, &[0], None, 1)?; + let avg = foreign_accum.evaluate(EmitTo::All)?; + assert_eq!(avg.len(), 1); + assert_eq!( + avg.as_ref(), + make_array(create_array!(Float64, vec![8.0]).to_data()).as_ref() + ); + + Ok(()) + } + + fn test_emit_to_round_trip(value: EmitTo) -> Result<()> { + let ffi_value: FFI_EmitTo = value.into(); + let round_trip_value: EmitTo = ffi_value.into(); + + assert_eq!(value, round_trip_value); + Ok(()) + } + + /// This test ensures all enum values are properly translated + #[test] + fn test_all_emit_to_round_trip() -> Result<()> { + test_emit_to_round_trip(EmitTo::All)?; + test_emit_to_round_trip(EmitTo::First(10))?; + + Ok(()) + } +} diff --git a/datafusion/ffi/src/udaf/mod.rs b/datafusion/ffi/src/udaf/mod.rs index 4107f7bd0ff4..018b2b07c2eb 100644 --- a/datafusion/ffi/src/udaf/mod.rs +++ b/datafusion/ffi/src/udaf/mod.rs @@ -29,6 +29,7 @@ use datafusion::{ error::DataFusionError, logical_expr::{ function::{AccumulatorArgs, AggregateFunctionSimplification, StateFieldsArgs}, + type_coercion::functions::data_types_with_aggregate_udf, utils::AggregateOrderSensitivity, Accumulator, GroupsAccumulator, }, @@ -112,6 +113,15 @@ pub struct 
FFI_AggregateUDF { pub order_sensitivity: unsafe extern "C" fn(udaf: &FFI_AggregateUDF) -> FFI_AggregateOrderSensitivity, + /// Performs type coersion. To simply this interface, all UDFs are treated as having + /// user defined signatures, which will in turn call coerce_types to be called. This + /// call should be transparent to most users as the internal function performs the + /// appropriate calls on the underlying [`ScalarUDF`] + pub coerce_types: unsafe extern "C" fn( + udf: &Self, + arg_types: RVec, + ) -> RResult, RString>, + /// Used to create a clone on the provider of the udaf. This should /// only need to be called by the receiver of the udaf. pub clone: unsafe extern "C" fn(udaf: &Self) -> Self, @@ -270,6 +280,19 @@ unsafe extern "C" fn order_sensitivity_fn_wrapper( udaf.inner().order_sensitivity().into() } +unsafe extern "C" fn coerce_types_fn_wrapper( + udaf: &FFI_AggregateUDF, + arg_types: RVec, +) -> RResult, RString> { + let udaf = udaf.inner(); + + let arg_types = rresult_return!(rvec_wrapped_to_vec_datatype(&arg_types)); + + let return_types = rresult_return!(data_types_with_aggregate_udf(&arg_types, udaf)); + + rresult!(vec_datatype_to_rvec_wrapped(&return_types)) +} + unsafe extern "C" fn release_fn_wrapper(udaf: &mut FFI_AggregateUDF) { let private_data = Box::from_raw(udaf.private_data as *mut AggregateUDFPrivateData); drop(private_data); @@ -307,6 +330,7 @@ impl From> for FFI_AggregateUDF { with_beneficial_ordering: with_beneficial_ordering_fn_wrapper, state_fields: state_fields_fn_wrapper, order_sensitivity: order_sensitivity_fn_wrapper, + coerce_types: coerce_types_fn_wrapper, clone: clone_fn_wrapper, release: release_fn_wrapper, private_data: Box::into_raw(private_data) as *mut c_void, @@ -493,6 +517,15 @@ impl AggregateUDFImpl for ForeignAggregateUDF { fn simplify(&self) -> Option { None } + + fn coerce_types(&self, arg_types: &[DataType]) -> Result> { + unsafe { + let arg_types = vec_datatype_to_rvec_wrapped(arg_types)?; + let 
result_types = + df_result!((self.udaf.coerce_types)(&self.udaf, arg_types))?; + Ok(rvec_wrapped_to_vec_datatype(&result_types)?) + } + } } #[cfg(test)] diff --git a/datafusion/ffi/tests/ffi_integration.rs b/datafusion/ffi/tests/ffi_integration.rs index c6df324e9a17..23cf50c3eae7 100644 --- a/datafusion/ffi/tests/ffi_integration.rs +++ b/datafusion/ffi/tests/ffi_integration.rs @@ -21,11 +21,15 @@ mod tests { use datafusion::error::{DataFusionError, Result}; - use datafusion::prelude::SessionContext; + use datafusion::prelude::{col, SessionContext}; use datafusion_ffi::catalog_provider::ForeignCatalogProvider; use datafusion_ffi::table_provider::ForeignTableProvider; - use datafusion_ffi::tests::create_record_batch; + use datafusion_ffi::tests::{create_record_batch, ForeignLibraryModuleRef}; use datafusion_ffi::tests::utils::get_module; + use datafusion::logical_expr::{AggregateUDF, ScalarUDF}; + use datafusion_ffi::udaf::ForeignAggregateUDF; + use datafusion_ffi::udf::ForeignScalarUDF; + use std::path::Path; use std::sync::Arc; /// It is important that this test is in the `tests` directory and not in the @@ -96,4 +100,44 @@ mod tests { Ok(()) } + + #[tokio::test] + async fn test_ffi_udaf() -> Result<()> { + let module = get_module()?; + + let ffi_avg_func = + module.create_udaf().ok_or(DataFusionError::NotImplemented( + "External table provider failed to implement create_udaf".to_string(), + ))?(); + let foreign_avg_func: ForeignAggregateUDF = (&ffi_avg_func).try_into()?; + + let udf: AggregateUDF = foreign_avg_func.into(); + + let ctx = SessionContext::default(); + let record_batch = record_batch!( + ("a", Int32, vec![1, 2, 2, 4, 4, 4, 4]), + ("b", Float64, vec![1.0, 2.0, 2.0, 4.0, 4.0, 4.0, 4.0]) + ) + .unwrap(); + + let df = ctx.read_batch(record_batch)?; + + let df = df + .aggregate( + vec![col("a")], + vec![udf.call(vec![col("b")]).alias("sum_b")], + )? 
+ .sort_by(vec![col("a")])?; + + let result = df.collect().await?; + + let expected = record_batch!( + ("a", Int32, vec![1, 2, 4]), + ("sum_b", Float64, vec![1.0, 4.0, 16.0]) + )?; + + assert_eq!(result[0], expected); + + Ok(()) + } } From eb6a07282ebe814f4724413f90d34bf04e2126d3 Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Sun, 23 Feb 2025 11:57:58 +0100 Subject: [PATCH 07/32] Working through additional unit and integration tests for UDAF ffi --- datafusion/ffi/src/tests/mod.rs | 12 ++-- datafusion/ffi/src/tests/udf_udaf_udwf.rs | 8 ++- datafusion/ffi/src/udaf/groups_accumulator.rs | 8 ++- datafusion/ffi/tests/ffi_integration.rs | 72 +++++++++++++++++-- 4 files changed, 87 insertions(+), 13 deletions(-) diff --git a/datafusion/ffi/src/tests/mod.rs b/datafusion/ffi/src/tests/mod.rs index e4b162644d5e..8605a84a3c8d 100644 --- a/datafusion/ffi/src/tests/mod.rs +++ b/datafusion/ffi/src/tests/mod.rs @@ -39,7 +39,7 @@ use datafusion::{ common::record_batch, }; use sync_provider::create_sync_table_provider; -use udf_udaf_udwf::{create_ffi_abs_func, create_ffi_avg_func, create_ffi_random_func, create_ffi_table_func}; +use udf_udaf_udwf::{create_ffi_abs_func, create_ffi_avg_func, create_ffi_random_func, create_ffi_stddev_func, create_ffi_table_func}; mod async_provider; pub mod catalog; @@ -67,8 +67,11 @@ pub struct ForeignLibraryModule { pub create_table_function: extern "C" fn() -> FFI_TableFunction, - /// Create an aggregate UDF - pub create_udaf: extern "C" fn() -> FFI_AggregateUDF, + /// Create an aggregate UDAF using sum + pub create_sum_udaf: extern "C" fn() -> FFI_AggregateUDF, + + /// Createa grouping UDAF using stddev + pub create_stddev_udaf: extern "C" fn() -> FFI_AggregateUDF, pub version: extern "C" fn() -> u64, } @@ -117,7 +120,8 @@ pub fn get_foreign_library_module() -> ForeignLibraryModuleRef { create_scalar_udf: create_ffi_abs_func, create_nullary_udf: create_ffi_random_func, create_table_function: create_ffi_table_func, - create_udaf: 
create_ffi_avg_func, + create_sum_udaf: create_ffi_avg_func, + create_stddev_udaf: create_ffi_stddev_func, version: super::version, } .leak_into_prefix() diff --git a/datafusion/ffi/src/tests/udf_udaf_udwf.rs b/datafusion/ffi/src/tests/udf_udaf_udwf.rs index 3edc383f4ca3..73caee695b33 100644 --- a/datafusion/ffi/src/tests/udf_udaf_udwf.rs +++ b/datafusion/ffi/src/tests/udf_udaf_udwf.rs @@ -19,7 +19,7 @@ use crate::{udaf::FFI_AggregateUDF, udf::FFI_ScalarUDF, udtf::FFI_TableFunction} use datafusion::{ catalog::TableFunctionImpl, functions::math::{abs::AbsFunc, random::RandomFunc}, - functions_aggregate::sum::Sum, + functions_aggregate::{stddev::Stddev, sum::Sum}, functions_table::generate_series::RangeFunc, logical_expr::{AggregateUDF, ScalarUDF}, }; @@ -49,3 +49,9 @@ pub(crate) extern "C" fn create_ffi_avg_func() -> FFI_AggregateUDF { udaf.into() } + +pub(crate) extern "C" fn create_ffi_stddev_func() -> FFI_AggregateUDF { + let udaf: Arc = Arc::new(Stddev::new().into()); + + udaf.into() +} diff --git a/datafusion/ffi/src/udaf/groups_accumulator.rs b/datafusion/ffi/src/udaf/groups_accumulator.rs index d5559e24b5f6..ccdf4000bdc7 100644 --- a/datafusion/ffi/src/udaf/groups_accumulator.rs +++ b/datafusion/ffi/src/udaf/groups_accumulator.rs @@ -476,8 +476,9 @@ mod tests { let mut foreign_accum: ForeignGroupsAccumulator = ffi_accum.into(); // Send in an array to evaluate. We want a mean of 30 and standard deviation of 4. 
- let values = create_array!(Float64, vec![26., 26., 34., 34.]); - foreign_accum.update_batch(&[values], &[0; 4], None, 1)?; + let values = create_array!(Float64, vec![26., 26., 34., 34., 0.0]); + let opt_filter = create_array!(Boolean, vec![true, true, true, true, false]); + foreign_accum.update_batch(&[values], &[0; 5], Some(opt_filter.as_ref()), 1)?; let groups_avg = foreign_accum.evaluate(EmitTo::All)?; let groups_avg = groups_avg.as_any().downcast_ref::().unwrap(); @@ -496,7 +497,8 @@ mod tests { make_array(create_array!(Float64, vec![64.0]).to_data()), ]; - foreign_accum.merge_batch(&second_states, &[0], None, 1)?; + let opt_filter = create_array!(Boolean, vec![true]); + foreign_accum.merge_batch(&second_states, &[0], Some(opt_filter.as_ref()), 1)?; let avg = foreign_accum.evaluate(EmitTo::All)?; assert_eq!(avg.len(), 1); assert_eq!( diff --git a/datafusion/ffi/tests/ffi_integration.rs b/datafusion/ffi/tests/ffi_integration.rs index 23cf50c3eae7..3ea65caadf02 100644 --- a/datafusion/ffi/tests/ffi_integration.rs +++ b/datafusion/ffi/tests/ffi_integration.rs @@ -20,6 +20,9 @@ #[cfg(feature = "integration-tests")] mod tests { + use abi_stable::library::RootModule; + use arrow::array::Float64Array; + use datafusion::common::record_batch; use datafusion::error::{DataFusionError, Result}; use datafusion::prelude::{col, SessionContext}; use datafusion_ffi::catalog_provider::ForeignCatalogProvider; @@ -106,12 +109,14 @@ mod tests { let module = get_module()?; let ffi_avg_func = - module.create_udaf().ok_or(DataFusionError::NotImplemented( - "External table provider failed to implement create_udaf".to_string(), - ))?(); + module + .create_sum_udaf() + .ok_or(DataFusionError::NotImplemented( + "External table provider failed to implement create_udaf".to_string(), + ))?(); let foreign_avg_func: ForeignAggregateUDF = (&ffi_avg_func).try_into()?; - let udf: AggregateUDF = foreign_avg_func.into(); + let udaf: AggregateUDF = foreign_avg_func.into(); let ctx = 
SessionContext::default(); let record_batch = record_batch!( @@ -125,7 +130,7 @@ mod tests { let df = df .aggregate( vec![col("a")], - vec![udf.call(vec![col("b")]).alias("sum_b")], + vec![udaf.call(vec![col("b")]).alias("sum_b")], )? .sort_by(vec![col("a")])?; @@ -140,4 +145,61 @@ mod tests { Ok(()) } + + #[tokio::test] + async fn test_ffi_grouping_udaf() -> Result<()> { + let module = get_module()?; + + let ffi_stddev_func = + module + .create_stddev_udaf() + .ok_or(DataFusionError::NotImplemented( + "External table provider failed to implement create_udaf".to_string(), + ))?(); + let foreign_stddev_func: ForeignAggregateUDF = (&ffi_stddev_func).try_into()?; + + let udaf: AggregateUDF = foreign_stddev_func.into(); + + let ctx = SessionContext::default(); + let record_batch = record_batch!( + ("a", Int32, vec![1, 2, 2, 4, 4, 4, 4]), + ( + "b", + Float64, + vec![ + 1.0, + 2.0, + 2.0 + 2.0_f64.sqrt(), + 4.0, + 4.0, + 4.0 + 3.0_f64.sqrt(), + 4.0 + 3.0_f64.sqrt() + ] + ) + ) + .unwrap(); + + let df = ctx.read_batch(record_batch)?; + + let df = df + .aggregate( + vec![col("a")], + vec![udaf.call(vec![col("b")]).alias("stddev_b")], + )? 
+ .sort_by(vec![col("a")])?; + + let result = df.collect().await?; + let result = result[0].column_by_name("stddev_b").unwrap(); + let result = result + .as_any() + .downcast_ref::() + .unwrap() + .values(); + + assert!(result.first().unwrap().is_nan()); + assert!(result.get(1).unwrap() - 1.0 < 0.00001); + assert!(result.get(2).unwrap() - 1.0 < 0.00001); + + Ok(()) + } } From 9d31d1f0de2cd777131b3ed70a5ca55b1ef86362 Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Sun, 23 Feb 2025 14:43:05 +0100 Subject: [PATCH 08/32] Switch to a accumulator that supports convert to state to get a little better coverage --- Cargo.lock | 1 + datafusion/ffi/Cargo.toml | 1 + datafusion/ffi/src/udaf/groups_accumulator.rs | 58 ++++++++++++------- 3 files changed, 38 insertions(+), 22 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index d0a125cbdbbd..d706821781f9 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2242,6 +2242,7 @@ dependencies = [ "async-ffi", "async-trait", "datafusion", + "datafusion-functions-aggregate-common", "datafusion-proto", "datafusion-proto-common", "doc-comment", diff --git a/datafusion/ffi/Cargo.toml b/datafusion/ffi/Cargo.toml index 2dadc2e5d541..963cfc6cd1c4 100644 --- a/datafusion/ffi/Cargo.toml +++ b/datafusion/ffi/Cargo.toml @@ -46,6 +46,7 @@ async-trait = { workspace = true } datafusion = { workspace = true, default-features = false } datafusion-proto = { workspace = true } datafusion-proto-common = { workspace = true } +datafusion-functions-aggregate-common = { workspace = true } futures = { workspace = true } log = { workspace = true } prost = { workspace = true } diff --git a/datafusion/ffi/src/udaf/groups_accumulator.rs b/datafusion/ffi/src/udaf/groups_accumulator.rs index ccdf4000bdc7..3f6b5def4f9b 100644 --- a/datafusion/ffi/src/udaf/groups_accumulator.rs +++ b/datafusion/ffi/src/udaf/groups_accumulator.rs @@ -457,53 +457,67 @@ impl From for EmitTo { #[cfg(test)] mod tests { - use arrow::array::{make_array, Array, Float64Array}; + use 
arrow::array::{make_array, Array, BooleanArray}; use datafusion::{ common::create_array, error::Result, - functions_aggregate::stddev::StddevGroupsAccumulator, logical_expr::{EmitTo, GroupsAccumulator}, - physical_plan::expressions::StatsType, }; + use datafusion_functions_aggregate_common::aggregate::groups_accumulator::bool_op::BooleanGroupsAccumulator; use super::{FFI_EmitTo, FFI_GroupsAccumulator, ForeignGroupsAccumulator}; #[test] fn test_foreign_avg_accumulator() -> Result<()> { let boxed_accum: Box = - Box::new(StddevGroupsAccumulator::new(StatsType::Population)); + Box::new(BooleanGroupsAccumulator::new(|a, b| a && b, true)); let ffi_accum: FFI_GroupsAccumulator = boxed_accum.into(); let mut foreign_accum: ForeignGroupsAccumulator = ffi_accum.into(); // Send in an array to evaluate. We want a mean of 30 and standard deviation of 4. - let values = create_array!(Float64, vec![26., 26., 34., 34., 0.0]); - let opt_filter = create_array!(Boolean, vec![true, true, true, true, false]); - foreign_accum.update_batch(&[values], &[0; 5], Some(opt_filter.as_ref()), 1)?; + let values = create_array!(Boolean, vec![true, true, true, false, true, true]); + let opt_filter = + create_array!(Boolean, vec![true, true, true, true, false, false]); + foreign_accum.update_batch( + &[values], + &[0, 0, 1, 1, 2, 2], + Some(opt_filter.as_ref()), + 3, + )?; + + let groups_bool = foreign_accum.evaluate(EmitTo::All)?; + let groups_bool = groups_bool.as_any().downcast_ref::().unwrap(); - let groups_avg = foreign_accum.evaluate(EmitTo::All)?; - let groups_avg = groups_avg.as_any().downcast_ref::().unwrap(); - let expected = 4.0; - assert_eq!(groups_avg.len(), 1); - assert!((groups_avg.value(0) - expected).abs() < 0.0001); + assert_eq!( + groups_bool, + create_array!(Boolean, vec![Some(true), Some(false), None]).as_ref() + ); let state = foreign_accum.state(EmitTo::All)?; - assert_eq!(state.len(), 3); + assert_eq!(state.len(), 1); // To verify merging batches works, create a second state 
to add in // This should cause our average to go down to 25.0 - let second_states = vec![ - make_array(create_array!(UInt64, vec![1]).to_data()), - make_array(create_array!(Float64, vec![30.0]).to_data()), - make_array(create_array!(Float64, vec![64.0]).to_data()), - ]; + let second_states = + vec![make_array(create_array!(Boolean, vec![false]).to_data())]; let opt_filter = create_array!(Boolean, vec![true]); foreign_accum.merge_batch(&second_states, &[0], Some(opt_filter.as_ref()), 1)?; - let avg = foreign_accum.evaluate(EmitTo::All)?; - assert_eq!(avg.len(), 1); + let groups_bool = foreign_accum.evaluate(EmitTo::All)?; + assert_eq!(groups_bool.len(), 1); + assert_eq!( + groups_bool.as_ref(), + make_array(create_array!(Boolean, vec![false]).to_data()).as_ref() + ); + + let values = create_array!(Boolean, vec![false]); + let opt_filter = create_array!(Boolean, vec![true]); + let groups_bool = + foreign_accum.convert_to_state(&[values], Some(opt_filter.as_ref()))?; + assert_eq!( - avg.as_ref(), - make_array(create_array!(Float64, vec![8.0]).to_data()).as_ref() + groups_bool[0].as_ref(), + make_array(create_array!(Boolean, vec![false]).to_data()).as_ref() ); Ok(()) From 217bb8e05cc63b371fb3ca3950dc818367d18632 Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Sun, 23 Feb 2025 19:58:33 +0100 Subject: [PATCH 09/32] Set feature so we do not get an error warning in stable rustc --- Cargo.toml | 2 +- datafusion/ffi/Cargo.toml | 1 + datafusion/ffi/src/arrow_wrappers.rs | 24 ++++++++++++++++++++++-- 3 files changed, 24 insertions(+), 3 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 767b66805fe4..64483eeb93da 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -216,5 +216,5 @@ unnecessary_lazy_evaluations = "warn" uninlined_format_args = "warn" [workspace.lints.rust] -unexpected_cfgs = { level = "warn", check-cfg = ["cfg(tarpaulin)"] } +unexpected_cfgs = { level = "warn", check-cfg = ["cfg(tarpaulin)", "cfg(tarpaulin_include)"] } unused_qualifications = "deny" diff --git 
a/datafusion/ffi/Cargo.toml b/datafusion/ffi/Cargo.toml index 963cfc6cd1c4..a84632b93bd5 100644 --- a/datafusion/ffi/Cargo.toml +++ b/datafusion/ffi/Cargo.toml @@ -58,3 +58,4 @@ doc-comment = { workspace = true } [features] integration-tests = [] +tarpaulin_include = [] # Exists only to prevent warnings on stable and still have accurate coverage diff --git a/datafusion/ffi/src/arrow_wrappers.rs b/datafusion/ffi/src/arrow_wrappers.rs index 547dd0156d9b..64bedb9cfe67 100644 --- a/datafusion/ffi/src/arrow_wrappers.rs +++ b/datafusion/ffi/src/arrow_wrappers.rs @@ -21,6 +21,7 @@ use abi_stable::StableAbi; use arrow::{ array::{make_array, ArrayRef}, datatypes::{Schema, SchemaRef}, + error::ArrowError, ffi::{from_ffi, to_ffi, FFI_ArrowArray, FFI_ArrowSchema}, }; use log::error; @@ -31,6 +32,16 @@ use log::error; #[derive(Debug, StableAbi)] pub struct WrappedSchema(#[sabi(unsafe_opaque_field)] pub FFI_ArrowSchema); +/// Some functions are expected to always succeed, like getting the schema from a TableProvider. +/// Since going through the FFI always has the potential to fail, we need to catch these errors, +/// give the user a warning, and return some kind of result. In this case we default to an +/// empty schema. +#[cfg(not(tarpaulin_include))] +fn catch_ffi_schema_error(e: ArrowError) -> FFI_ArrowSchema { + error!("Unable to convert DataFusion Schema to FFI_ArrowSchema in FFI_PlanProperties. {}", e); + FFI_ArrowSchema::empty() +} + impl From for WrappedSchema { fn from(value: SchemaRef) -> Self { let ffi_schema = match FFI_ArrowSchema::try_from(value.as_ref()) { @@ -44,6 +55,15 @@ impl From for WrappedSchema { WrappedSchema(ffi_schema) } } +/// Some functions are expected to always succeed, like getting the schema from a TableProvider. +/// Since going through the FFI always has the potential to fail, we need to catch these errors, +/// give the user a warning, and return some kind of result. In this case we default to an +/// empty schema. 
+#[cfg(not(tarpaulin_include))] +fn catch_df_schema_error(e: ArrowError) -> Schema { + error!("Unable to convert from FFI_ArrowSchema to DataFusion Schema in FFI_PlanProperties. {}", e); + Schema::empty() +} impl From for SchemaRef { fn from(value: WrappedSchema) -> Self { @@ -71,7 +91,7 @@ pub struct WrappedArray { } impl TryFrom for ArrayRef { - type Error = arrow::error::ArrowError; + type Error = ArrowError; fn try_from(value: WrappedArray) -> Result { let data = unsafe { from_ffi(value.array, &value.schema.0)? }; @@ -81,7 +101,7 @@ impl TryFrom for ArrayRef { } impl TryFrom<&ArrayRef> for WrappedArray { - type Error = arrow::error::ArrowError; + type Error = ArrowError; fn try_from(array: &ArrayRef) -> Result { let (array, schema) = to_ffi(&array.to_data())?; From b5b11d403b22317f0c5eb13eb6849d13739d9786 Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Mon, 24 Feb 2025 08:21:33 +0100 Subject: [PATCH 10/32] Add more options to test --- datafusion/ffi/src/plan_properties.rs | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/datafusion/ffi/src/plan_properties.rs b/datafusion/ffi/src/plan_properties.rs index 5c878fa4be79..587e667a4775 100644 --- a/datafusion/ffi/src/plan_properties.rs +++ b/datafusion/ffi/src/plan_properties.rs @@ -300,7 +300,10 @@ impl From for EmissionType { #[cfg(test)] mod tests { - use datafusion::physical_plan::Partitioning; + use datafusion::{ + physical_expr::{LexOrdering, PhysicalSortExpr}, + physical_plan::Partitioning, + }; use super::*; @@ -311,8 +314,13 @@ mod tests { Arc::new(Schema::new(vec![Field::new("a", DataType::Float32, false)])); let original_props = PlanProperties::new( - EquivalenceProperties::new(schema), - Partitioning::UnknownPartitioning(3), + EquivalenceProperties::new(Arc::clone(&schema)).with_reorder( + LexOrdering::new(vec![PhysicalSortExpr { + expr: datafusion::physical_plan::expressions::col("a", &schema)?, + options: Default::default(), + }]), + ), + Partitioning::RoundRobinBatch(3), 
EmissionType::Incremental, Boundedness::Bounded, ); From ae61d88f459079c747b9f3e12dc172dd4fcbd679 Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Tue, 25 Feb 2025 06:54:03 +0100 Subject: [PATCH 11/32] Add unit test for FFI RecordBatchStream --- datafusion/ffi/src/record_batch_stream.rs | 45 +++++++++++++++++++++++ datafusion/ffi/src/util.rs | 2 +- 2 files changed, 46 insertions(+), 1 deletion(-) diff --git a/datafusion/ffi/src/record_batch_stream.rs b/datafusion/ffi/src/record_batch_stream.rs index 939c4050028c..5663fb12f0e9 100644 --- a/datafusion/ffi/src/record_batch_stream.rs +++ b/datafusion/ffi/src/record_batch_stream.rs @@ -196,3 +196,48 @@ impl Stream for FFI_RecordBatchStream { } } } + +#[cfg(test)] +mod tests { + use std::sync::Arc; + + use arrow::datatypes::{DataType, Field, Schema}; + use datafusion::{ + common::record_batch, error::Result, execution::SendableRecordBatchStream, + test_util::bounded_stream, + }; + + use super::FFI_RecordBatchStream; + use futures::StreamExt; + + #[tokio::test] + async fn test_round_trip_record_batch_stream() -> Result<()> { + let record_batch = record_batch!( + ("a", Int32, vec![1, 2, 3]), + ("b", Float64, vec![Some(4.0), None, Some(5.0)]) + )?; + let original_rbs = bounded_stream(record_batch.clone(), 1); + + let ffi_rbs: FFI_RecordBatchStream = original_rbs.into(); + let mut ffi_rbs: SendableRecordBatchStream = Box::pin(ffi_rbs); + + let schema = ffi_rbs.schema(); + assert_eq!( + schema, + Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Float64, true) + ])) + ); + + let batch = ffi_rbs.next().await; + assert!(batch.is_some()); + assert!(batch.unwrap().is_ok()); + + // There should only be one batch + let no_batch = ffi_rbs.next().await; + assert!(no_batch.is_none()); + + Ok(()) + } +} diff --git a/datafusion/ffi/src/util.rs b/datafusion/ffi/src/util.rs index 3eb57963b44f..abe369c57298 100644 --- a/datafusion/ffi/src/util.rs +++ b/datafusion/ffi/src/util.rs @@ -22,7 +22,7 
@@ use arrow::{datatypes::DataType, ffi::FFI_ArrowSchema}; use arrow_schema::FieldRef; use std::sync::Arc; -/// This macro is a helpful conversion utility to conver from an abi_stable::RResult to a +/// This macro is a helpful conversion utility to convert from an abi_stable::RResult to a /// DataFusion result. #[macro_export] macro_rules! df_result { From dfd3268758415194864180bfe6efcec44b38473a Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Tue, 25 Feb 2025 08:57:16 +0100 Subject: [PATCH 12/32] Add a few more args to ffi accumulator test fn --- datafusion/ffi/src/udaf/accumulator_args.rs | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/datafusion/ffi/src/udaf/accumulator_args.rs b/datafusion/ffi/src/udaf/accumulator_args.rs index b3933c267025..3a25d09c4a55 100644 --- a/datafusion/ffi/src/udaf/accumulator_args.rs +++ b/datafusion/ffi/src/udaf/accumulator_args.rs @@ -156,23 +156,29 @@ impl<'a> From<&'a ForeignAccumulatorArgs> for AccumulatorArgs<'a> { #[cfg(test)] mod tests { use super::{FFI_AccumulatorArgs, ForeignAccumulatorArgs}; - use arrow::datatypes::{DataType, Schema}; + use arrow::datatypes::{DataType, Field, Schema}; use datafusion::{ - error::Result, logical_expr::function::AccumulatorArgs, - physical_expr::LexOrdering, + error::Result, + logical_expr::function::AccumulatorArgs, + physical_expr::{LexOrdering, PhysicalSortExpr}, + physical_plan::expressions::col, }; #[test] fn test_round_trip_accumulator_args() -> Result<()> { + let schema = Schema::new(vec![Field::new("a", DataType::Int32, true)]); let orig_args = AccumulatorArgs { return_type: &DataType::Float64, - schema: &Schema::empty(), + schema: &schema, ignore_nulls: false, - ordering_req: &LexOrdering::new(vec![]), + ordering_req: &LexOrdering::new(vec![PhysicalSortExpr { + expr: col("a", &schema)?, + options: Default::default(), + }]), is_reversed: false, name: "round_trip", is_distinct: true, - exprs: &[], + exprs: &[col("a", &schema)?], }; let orig_str = 
format!("{:?}", orig_args); @@ -185,6 +191,7 @@ mod tests { // Since AccumulatorArgs doesn't implement Eq, simply compare // the debug strings. assert_eq!(orig_str, round_trip_str); + println!("{}", round_trip_str); Ok(()) } From aa4b7ceafbc935342a1b55dcffb41e04e05d7d1d Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Thu, 27 Feb 2025 07:59:48 +0100 Subject: [PATCH 13/32] Adding more unit tests on ffi aggregate udaf --- datafusion/ffi/src/udaf/accumulator.rs | 12 +- datafusion/ffi/src/udaf/mod.rs | 168 ++++++++++++++++++++++--- 2 files changed, 160 insertions(+), 20 deletions(-) diff --git a/datafusion/ffi/src/udaf/accumulator.rs b/datafusion/ffi/src/udaf/accumulator.rs index c4f8edfeafc5..a6c007dce8f7 100644 --- a/datafusion/ffi/src/udaf/accumulator.rs +++ b/datafusion/ffi/src/udaf/accumulator.rs @@ -307,7 +307,11 @@ mod tests { #[test] fn test_foreign_avg_accumulator() -> Result<()> { - let boxed_accum: Box = Box::new(AvgAccumulator::default()); + let original_accum = AvgAccumulator::default(); + let original_size = original_accum.size(); + let original_supports_retract = original_accum.supports_retract_batch(); + + let boxed_accum: Box = Box::new(original_accum); let ffi_accum: FFI_Accumulator = boxed_accum.into(); let mut foreign_accum: ForeignAccumulator = ffi_accum.into(); @@ -341,6 +345,12 @@ mod tests { let avg = foreign_accum.evaluate()?; assert_eq!(avg, ScalarValue::Float64(Some(30.0))); + assert_eq!(original_size, foreign_accum.size()); + assert_eq!( + original_supports_retract, + foreign_accum.supports_retract_batch() + ); + Ok(()) } } diff --git a/datafusion/ffi/src/udaf/mod.rs b/datafusion/ffi/src/udaf/mod.rs index 018b2b07c2eb..cec7cf08a3e6 100644 --- a/datafusion/ffi/src/udaf/mod.rs +++ b/datafusion/ffi/src/udaf/mod.rs @@ -528,25 +528,6 @@ impl AggregateUDFImpl for ForeignAggregateUDF { } } -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_round_trip_udaf() -> Result<()> { - let original_udaf = 
datafusion::functions_aggregate::sum::Sum::new(); - let original_udaf = Arc::new(AggregateUDF::from(original_udaf)); - - let local_udaf: FFI_AggregateUDF = Arc::clone(&original_udaf).into(); - - let foreign_udaf: ForeignAggregateUDF = (&local_udaf).try_into()?; - - assert!(original_udaf.name() == foreign_udaf.name()); - - Ok(()) - } -} - #[repr(C)] #[derive(Debug, StableAbi)] #[allow(non_camel_case_types)] @@ -575,3 +556,152 @@ impl From for FFI_AggregateOrderSensitivity { } } } + +#[cfg(test)] +mod tests { + use arrow::datatypes::Schema; + use datafusion::{ + common::create_array, + functions_aggregate::sum::Sum, + physical_expr::{LexOrdering, PhysicalSortExpr}, + physical_plan::expressions::col, + scalar::ScalarValue, + }; + + use super::*; + + fn create_test_foreign_udaf( + original_udaf: impl AggregateUDFImpl + 'static, + ) -> Result { + let original_udaf = Arc::new(AggregateUDF::from(original_udaf)); + + let local_udaf: FFI_AggregateUDF = Arc::clone(&original_udaf).into(); + + let foreign_udaf: ForeignAggregateUDF = (&local_udaf).try_into()?; + Ok(foreign_udaf.into()) + } + + #[test] + fn test_round_trip_udaf() -> Result<()> { + let original_udaf = Sum::new(); + let original_name = original_udaf.name().to_owned(); + + let foreign_udaf = create_test_foreign_udaf(original_udaf)?; + // let original_udaf = Arc::new(AggregateUDF::from(original_udaf)); + + // let local_udaf: FFI_AggregateUDF = Arc::clone(&original_udaf).into(); + + // let foreign_udaf: ForeignAggregateUDF = (&local_udaf).try_into()?; + // let foreign_udaf: AggregateUDF = foreign_udaf.into(); + + assert_eq!(original_name, foreign_udaf.name()); + Ok(()) + } + + #[test] + fn test_foreign_udaf_aliases() -> Result<()> { + let foreign_udaf = + create_test_foreign_udaf(Sum::new())?.with_aliases(["my_function"]); + + let return_type = foreign_udaf.return_type(&[DataType::Float64])?; + assert_eq!(return_type, DataType::Float64); + Ok(()) + } + + #[test] + fn test_foreign_udaf_accumulator() -> Result<()> { + 
let foreign_udaf = create_test_foreign_udaf(Sum::new())?; + + let schema = Schema::new(vec![Field::new("a", DataType::Float64, true)]); + let acc_args = AccumulatorArgs { + return_type: &DataType::Float64, + schema: &schema, + ignore_nulls: true, + ordering_req: &LexOrdering::new(vec![PhysicalSortExpr { + expr: col("a", &schema)?, + options: Default::default(), + }]), + is_reversed: false, + name: "round_trip", + is_distinct: true, + exprs: &[col("a", &schema)?], + }; + let mut accumulator = foreign_udaf.accumulator(acc_args)?; + let values = create_array!(Float64, vec![10., 20., 30., 40., 50.]); + accumulator.update_batch(&[values])?; + let resultant_value = accumulator.evaluate()?; + assert_eq!(resultant_value, ScalarValue::Float64(Some(150.))); + + Ok(()) + } + + #[test] + fn test_beneficial_ordering() -> Result<()> { + let foreign_udaf = create_test_foreign_udaf( + datafusion::functions_aggregate::first_last::FirstValue::new(), + )?; + + let foreign_udaf = foreign_udaf.with_beneficial_ordering(true)?.unwrap(); + + assert_eq!( + foreign_udaf.order_sensitivity(), + AggregateOrderSensitivity::Beneficial + ); + + let a_field = Field::new("a", DataType::Float64, true); + let state_fields = foreign_udaf.state_fields(StateFieldsArgs { + name: "a", + input_types: &[DataType::Float64], + return_type: &DataType::Float64, + ordering_fields: &[a_field.clone()], + is_distinct: false, + })?; + + println!("{:#?}", state_fields); + assert_eq!(state_fields.len(), 3); + assert_eq!(state_fields[1], a_field); + Ok(()) + } + + #[test] + fn test_sliding_accumulator() -> Result<()> { + let foreign_udaf = create_test_foreign_udaf(Sum::new())?; + + let schema = Schema::new(vec![Field::new("a", DataType::Float64, true)]); + let acc_args = AccumulatorArgs { + return_type: &DataType::Float64, + schema: &schema, + ignore_nulls: true, + ordering_req: &LexOrdering::new(vec![PhysicalSortExpr { + expr: col("a", &schema)?, + options: Default::default(), + }]), + is_reversed: false, + name: 
"round_trip", + is_distinct: true, + exprs: &[col("a", &schema)?], + }; + + let mut accumulator = foreign_udaf.create_sliding_accumulator(acc_args)?; + let values = create_array!(Float64, vec![10., 20., 30., 40., 50.]); + accumulator.update_batch(&[values])?; + let resultant_value = accumulator.evaluate()?; + assert_eq!(resultant_value, ScalarValue::Float64(Some(150.))); + + Ok(()) + } + + fn test_round_trip_order_sensitivity(sensitivity: AggregateOrderSensitivity) { + let ffi_sensitivity: FFI_AggregateOrderSensitivity = sensitivity.into(); + let round_trip_sensitivity: AggregateOrderSensitivity = ffi_sensitivity.into(); + + assert_eq!(sensitivity, round_trip_sensitivity); + } + + #[test] + fn test_round_trip_all_order_sensitivities() { + test_round_trip_order_sensitivity(AggregateOrderSensitivity::Insensitive); + test_round_trip_order_sensitivity(AggregateOrderSensitivity::HardRequirement); + test_round_trip_order_sensitivity(AggregateOrderSensitivity::Beneficial); + } +} From 2de6d0a5a654d445f1e1aefb9a6da98cc97be7ad Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Thu, 27 Feb 2025 08:14:37 +0100 Subject: [PATCH 14/32] taplo format --- datafusion/ffi/Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datafusion/ffi/Cargo.toml b/datafusion/ffi/Cargo.toml index a84632b93bd5..a8335769ec29 100644 --- a/datafusion/ffi/Cargo.toml +++ b/datafusion/ffi/Cargo.toml @@ -44,9 +44,9 @@ arrow-schema = { workspace = true } async-ffi = { version = "0.5.0", features = ["abi_stable"] } async-trait = { workspace = true } datafusion = { workspace = true, default-features = false } +datafusion-functions-aggregate-common = { workspace = true } datafusion-proto = { workspace = true } datafusion-proto-common = { workspace = true } -datafusion-functions-aggregate-common = { workspace = true } futures = { workspace = true } log = { workspace = true } prost = { workspace = true } From 11f88de866a55ecf3c80e076159aabccac96e553 Mon Sep 17 00:00:00 2001 From: Tim Saucer 
Date: Thu, 27 Feb 2025 08:15:49 +0100 Subject: [PATCH 15/32] Update code comment --- datafusion/ffi/src/udaf/mod.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datafusion/ffi/src/udaf/mod.rs b/datafusion/ffi/src/udaf/mod.rs index cec7cf08a3e6..ac59cef8fdf4 100644 --- a/datafusion/ffi/src/udaf/mod.rs +++ b/datafusion/ffi/src/udaf/mod.rs @@ -116,7 +116,7 @@ pub struct FFI_AggregateUDF { /// Performs type coersion. To simply this interface, all UDFs are treated as having /// user defined signatures, which will in turn call coerce_types to be called. This /// call should be transparent to most users as the internal function performs the - /// appropriate calls on the underlying [`ScalarUDF`] + /// appropriate calls on the underlying [`AggregateUDF`] pub coerce_types: unsafe extern "C" fn( udf: &Self, arg_types: RVec, From 9300ba57bbe63546d32be768cdfcadc2da2f62b6 Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Thu, 27 Feb 2025 21:50:39 +0100 Subject: [PATCH 16/32] Correct function name --- datafusion/ffi/src/tests/mod.rs | 4 ++-- datafusion/ffi/src/tests/udf_udaf_udwf.rs | 3 ++- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/datafusion/ffi/src/tests/mod.rs b/datafusion/ffi/src/tests/mod.rs index 8605a84a3c8d..7a854fd1c33e 100644 --- a/datafusion/ffi/src/tests/mod.rs +++ b/datafusion/ffi/src/tests/mod.rs @@ -39,7 +39,7 @@ use datafusion::{ common::record_batch, }; use sync_provider::create_sync_table_provider; -use udf_udaf_udwf::{create_ffi_abs_func, create_ffi_avg_func, create_ffi_random_func, create_ffi_stddev_func, create_ffi_table_func}; +use udf_udaf_udwf::{create_ffi_abs_func, create_ffi_random_func, create_ffi_stddev_func, create_ffi_sum_func, create_ffi_table_func}; mod async_provider; pub mod catalog; @@ -120,7 +120,7 @@ pub fn get_foreign_library_module() -> ForeignLibraryModuleRef { create_scalar_udf: create_ffi_abs_func, create_nullary_udf: create_ffi_random_func, create_table_function: create_ffi_table_func, - 
create_sum_udaf: create_ffi_avg_func, + create_sum_udaf: create_ffi_sum_func, create_stddev_udaf: create_ffi_stddev_func, version: super::version, } diff --git a/datafusion/ffi/src/tests/udf_udaf_udwf.rs b/datafusion/ffi/src/tests/udf_udaf_udwf.rs index 73caee695b33..b2a4194fa355 100644 --- a/datafusion/ffi/src/tests/udf_udaf_udwf.rs +++ b/datafusion/ffi/src/tests/udf_udaf_udwf.rs @@ -44,7 +44,8 @@ pub(crate) extern "C" fn create_ffi_table_func() -> FFI_TableFunction { FFI_TableFunction::new(udtf, None) } -pub(crate) extern "C" fn create_ffi_avg_func() -> FFI_AggregateUDF { + +pub(crate) extern "C" fn create_ffi_sum_func() -> FFI_AggregateUDF { let udaf: Arc = Arc::new(Sum::new().into()); udaf.into() From ec05091983164774503041ffa9d59b756a45a6fc Mon Sep 17 00:00:00 2001 From: Crystal Zhou Date: Tue, 1 Apr 2025 10:49:58 -0400 Subject: [PATCH 17/32] Temp fix record batch test dependencies --- .../src/datasource/file_format/parquet.rs | 140 +++++++++++++----- datafusion/core/src/lib.rs | 1 + datafusion/core/src/test_util/mod.rs | 47 ++++++ 3 files changed, 152 insertions(+), 36 deletions(-) diff --git a/datafusion/core/src/datasource/file_format/parquet.rs b/datafusion/core/src/datasource/file_format/parquet.rs index 9705225c24c7..ab1e938336d4 100644 --- a/datafusion/core/src/datasource/file_format/parquet.rs +++ b/datafusion/core/src/datasource/file_format/parquet.rs @@ -107,10 +107,8 @@ pub(crate) mod test_util { mod tests { use std::fmt::{self, Display, Formatter}; - use std::pin::Pin; use std::sync::atomic::{AtomicUsize, Ordering}; use std::sync::Arc; - use std::task::{Context, Poll}; use std::time::Duration; use crate::datasource::file_format::parquet::test_util::store_parquet; @@ -120,7 +118,7 @@ mod tests { use crate::prelude::{ParquetReadOptions, SessionConfig, SessionContext}; use arrow::array::RecordBatch; - use arrow_schema::{Schema, SchemaRef}; + use arrow_schema::Schema; use datafusion_catalog::Session; use datafusion_common::cast::{ as_binary_array, 
as_binary_view_array, as_boolean_array, as_float32_array, @@ -140,7 +138,7 @@ mod tests { }; use datafusion_execution::object_store::ObjectStoreUrl; use datafusion_execution::runtime_env::RuntimeEnv; - use datafusion_execution::{RecordBatchStream, TaskContext}; + use datafusion_execution::TaskContext; use datafusion_expr::dml::InsertOp; use datafusion_physical_plan::stream::RecordBatchStreamAdapter; use datafusion_physical_plan::{collect, ExecutionPlan}; @@ -153,7 +151,7 @@ mod tests { use async_trait::async_trait; use datafusion_datasource::file_groups::FileGroup; use futures::stream::BoxStream; - use futures::{Stream, StreamExt}; + use futures::StreamExt; use insta::assert_snapshot; use log::error; use object_store::local::LocalFileSystem; @@ -169,6 +167,8 @@ mod tests { use parquet::format::FileMetaData; use tokio::fs::File; + use crate::test_util::bounded_stream; + enum ForceViews { Yes, No, @@ -1663,42 +1663,110 @@ mod tests { Ok(()) } - /// Creates an bounded stream for testing purposes. 
- fn bounded_stream( - batch: RecordBatch, - limit: usize, - ) -> datafusion_execution::SendableRecordBatchStream { - Box::pin(BoundedStream { - count: 0, - limit, - batch, - }) - } + #[tokio::test] + async fn test_memory_reservation_column_parallel() -> Result<()> { + async fn test_memory_reservation(global: ParquetOptions) -> Result<()> { + let field_a = Field::new("a", DataType::Utf8, false); + let field_b = Field::new("b", DataType::Utf8, false); + let schema = Arc::new(Schema::new(vec![field_a, field_b])); + let object_store_url = ObjectStoreUrl::local_filesystem(); - struct BoundedStream { - limit: usize, - count: usize, - batch: RecordBatch, - } + let file_sink_config = FileSinkConfig { + original_url: String::default(), + object_store_url: object_store_url.clone(), + file_group: FileGroup::new(vec![PartitionedFile::new( + "/tmp".to_string(), + 1, + )]), + table_paths: vec![ListingTableUrl::parse("file:///")?], + output_schema: schema.clone(), + table_partition_cols: vec![], + insert_op: InsertOp::Overwrite, + keep_partition_by_columns: false, + file_extension: "parquet".into(), + }; + let parquet_sink = Arc::new(ParquetSink::new( + file_sink_config, + TableParquetOptions { + key_value_metadata: std::collections::HashMap::from([ + ("my-data".to_string(), Some("stuff".to_string())), + ("my-data-bool-key".to_string(), None), + ]), + global, + ..Default::default() + }, + )); + + // create data + let col_a: ArrayRef = Arc::new(StringArray::from(vec!["foo", "bar"])); + let col_b: ArrayRef = Arc::new(StringArray::from(vec!["baz", "baz"])); + let batch = + RecordBatch::try_from_iter(vec![("a", col_a), ("b", col_b)]).unwrap(); - impl Stream for BoundedStream { - type Item = Result; + // create task context + let task_context = build_ctx(object_store_url.as_ref()); + assert_eq!( + task_context.memory_pool().reserved(), + 0, + "no bytes are reserved yet" + ); - fn poll_next( - mut self: Pin<&mut Self>, - _cx: &mut Context<'_>, - ) -> Poll> { - if self.count >= 
self.limit { - return Poll::Ready(None); + let mut write_task = FileSink::write_all( + parquet_sink.as_ref(), + Box::pin(RecordBatchStreamAdapter::new( + schema, + bounded_stream(batch, 1000), + )), + &task_context, + ); + + // incrementally poll and check for memory reservation + let mut reserved_bytes = 0; + while futures::poll!(&mut write_task).is_pending() { + reserved_bytes += task_context.memory_pool().reserved(); + tokio::time::sleep(Duration::from_micros(1)).await; } - self.count += 1; - Poll::Ready(Some(Ok(self.batch.clone()))) - } - } + assert!( + reserved_bytes > 0, + "should have bytes reserved during write" + ); + assert_eq!( + task_context.memory_pool().reserved(), + 0, + "no leaking byte reservation" + ); - impl RecordBatchStream for BoundedStream { - fn schema(&self) -> SchemaRef { - self.batch.schema() + Ok(()) } + + let write_opts = ParquetOptions { + allow_single_file_parallelism: false, + ..Default::default() + }; + test_memory_reservation(write_opts) + .await + .expect("should track for non-parallel writes"); + + let row_parallel_write_opts = ParquetOptions { + allow_single_file_parallelism: true, + maximum_parallel_row_group_writers: 10, + maximum_buffered_record_batches_per_stream: 1, + ..Default::default() + }; + test_memory_reservation(row_parallel_write_opts) + .await + .expect("should track for row-parallel writes"); + + let col_parallel_write_opts = ParquetOptions { + allow_single_file_parallelism: true, + maximum_parallel_row_group_writers: 1, + maximum_buffered_record_batches_per_stream: 2, + ..Default::default() + }; + test_memory_reservation(col_parallel_write_opts) + .await + .expect("should track for column-parallel writes"); + + Ok(()) } } diff --git a/datafusion/core/src/lib.rs b/datafusion/core/src/lib.rs index 6956108e2df3..bce08231ba62 100644 --- a/datafusion/core/src/lib.rs +++ b/datafusion/core/src/lib.rs @@ -843,6 +843,7 @@ pub mod test; mod schema_equivalence; pub mod test_util; +// pub use test_util::bounded_stream; 
#[cfg(doctest)] doc_comment::doctest!("../../../README.md", readme_example_test); diff --git a/datafusion/core/src/test_util/mod.rs b/datafusion/core/src/test_util/mod.rs index d6865ca3d532..2f8e66a2bbfb 100644 --- a/datafusion/core/src/test_util/mod.rs +++ b/datafusion/core/src/test_util/mod.rs @@ -22,12 +22,14 @@ pub mod parquet; pub mod csv; +use futures::Stream; use std::any::Any; use std::collections::HashMap; use std::fs::File; use std::io::Write; use std::path::Path; use std::sync::Arc; +use std::task::{Context, Poll}; use crate::catalog::{TableProvider, TableProviderFactory}; use crate::dataframe::DataFrame; @@ -38,11 +40,13 @@ use crate::logical_expr::{LogicalPlanBuilder, UNNAMED_TABLE}; use crate::physical_plan::ExecutionPlan; use crate::prelude::{CsvReadOptions, SessionContext}; +use crate::execution::SendableRecordBatchStream; use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use arrow::record_batch::RecordBatch; use datafusion_catalog::Session; use datafusion_common::TableReference; use datafusion_expr::{CreateExternalTable, Expr, SortExpr, TableType}; +use std::pin::Pin; use async_trait::async_trait; @@ -52,6 +56,8 @@ use tempfile::TempDir; pub use datafusion_common::test_util::parquet_test_data; pub use datafusion_common::test_util::{arrow_test_data, get_data_dir}; +use crate::execution::RecordBatchStream; + /// Scan an empty data source, mainly used in tests pub fn scan_empty( name: Option<&str>, @@ -234,3 +240,44 @@ pub fn register_unbounded_file_with_ordering( ctx.register_table(table_name, Arc::new(StreamTable::new(Arc::new(config))))?; Ok(()) } + +/// Creates a bounded stream that emits the same record batch a specified number of times. +/// This is useful for testing purposes. 
+pub fn bounded_stream( + record_batch: RecordBatch, + limit: usize, +) -> SendableRecordBatchStream { + Box::pin(BoundedStream { + record_batch, + count: 0, + limit, + }) +} + +struct BoundedStream { + record_batch: RecordBatch, + count: usize, + limit: usize, +} + +impl Stream for BoundedStream { + type Item = Result; + + fn poll_next( + mut self: Pin<&mut Self>, + _cx: &mut Context<'_>, + ) -> Poll> { + if self.count >= self.limit { + Poll::Ready(None) + } else { + self.count += 1; + Poll::Ready(Some(Ok(self.record_batch.clone()))) + } + } +} + +impl RecordBatchStream for BoundedStream { + fn schema(&self) -> SchemaRef { + self.record_batch.schema() + } +} From fc40bc0c0f4a9c6314e164182ddfe19418e18ddb Mon Sep 17 00:00:00 2001 From: Crystal Zhou Date: Tue, 1 Apr 2025 20:36:57 -0400 Subject: [PATCH 18/32] Address some comments --- datafusion/ffi/src/record_batch_stream.rs | 3 +- datafusion/ffi/src/udaf/accumulator.rs | 37 +++++++++++++---------- 2 files changed, 23 insertions(+), 17 deletions(-) diff --git a/datafusion/ffi/src/record_batch_stream.rs b/datafusion/ffi/src/record_batch_stream.rs index 5663fb12f0e9..78d65a816fcc 100644 --- a/datafusion/ffi/src/record_batch_stream.rs +++ b/datafusion/ffi/src/record_batch_stream.rs @@ -232,7 +232,8 @@ mod tests { let batch = ffi_rbs.next().await; assert!(batch.is_some()); - assert!(batch.unwrap().is_ok()); + assert!(batch.as_ref().unwrap().is_ok()); + assert_eq!(batch.unwrap().unwrap(), record_batch); // There should only be one batch let no_batch = ffi_rbs.next().await; diff --git a/datafusion/ffi/src/udaf/accumulator.rs b/datafusion/ffi/src/udaf/accumulator.rs index a6c007dce8f7..897cd9f49cc3 100644 --- a/datafusion/ffi/src/udaf/accumulator.rs +++ b/datafusion/ffi/src/udaf/accumulator.rs @@ -31,6 +31,7 @@ use prost::Message; use crate::{arrow_wrappers::WrappedArray, df_result, rresult, rresult_return}; +/// A stable struct for sharing [`Accumulator`] across FFI boundaries. 
#[repr(C)] #[derive(Debug, StableAbi)] #[allow(non_camel_case_types)] @@ -75,12 +76,19 @@ pub struct AccumulatorPrivateData { pub accumulator: Box, } +impl FFI_Accumulator { + #[inline] + unsafe fn inner(&self) -> &mut AccumulatorPrivateData { + let private_data = self.private_data as *mut AccumulatorPrivateData; + &mut (*private_data) + } +} + unsafe extern "C" fn update_batch_fn_wrapper( accumulator: &mut FFI_Accumulator, values: RVec, ) -> RResult<(), RString> { - let private_data = accumulator.private_data as *mut AccumulatorPrivateData; - let accum_data = &mut (*private_data); + let accum_data = accumulator.inner(); let values_arrays = values .into_iter() @@ -94,8 +102,7 @@ unsafe extern "C" fn update_batch_fn_wrapper( unsafe extern "C" fn evaluate_fn_wrapper( accumulator: &FFI_Accumulator, ) -> RResult, RString> { - let private_data = accumulator.private_data as *mut AccumulatorPrivateData; - let accum_data = &mut (*private_data); + let accum_data = accumulator.inner(); let scalar_result = rresult_return!(accum_data.accumulator.evaluate()); let proto_result: datafusion_proto::protobuf::ScalarValue = @@ -105,10 +112,9 @@ unsafe extern "C" fn evaluate_fn_wrapper( } unsafe extern "C" fn size_fn_wrapper(accumulator: &FFI_Accumulator) -> usize { - let private_data = accumulator.private_data as *mut AccumulatorPrivateData; - let accum_data = &mut (*private_data); - - accum_data.accumulator.size() + // let private_data = accumulator.private_data as *mut AccumulatorPrivateData; + // let accum_data = &mut (*private_data); + accumulator.inner().accumulator.size() } unsafe extern "C" fn state_fn_wrapper( @@ -135,8 +141,7 @@ unsafe extern "C" fn merge_batch_fn_wrapper( accumulator: &mut FFI_Accumulator, states: RVec, ) -> RResult<(), RString> { - let private_data = accumulator.private_data as *mut AccumulatorPrivateData; - let accum_data = &mut (*private_data); + let accum_data = accumulator.inner(); let states = rresult_return!(states .into_iter() @@ -150,15 +155,15 @@ 
unsafe extern "C" fn retract_batch_fn_wrapper( accumulator: &mut FFI_Accumulator, values: RVec, ) -> RResult<(), RString> { - let private_data = accumulator.private_data as *mut AccumulatorPrivateData; - let accum_data = &mut (*private_data); + let accum_data = accumulator.inner(); - let values = rresult_return!(values + let values_arrays = values .into_iter() - .map(|state| ArrayRef::try_from(state).map_err(DataFusionError::from)) - .collect::>>()); + .map(|v| v.try_into().map_err(DataFusionError::from)) + .collect::>>(); + let values_arrays = rresult_return!(values_arrays); - rresult!(accum_data.accumulator.retract_batch(&values)) + rresult!(accum_data.accumulator.retract_batch(&values_arrays)) } unsafe extern "C" fn release_fn_wrapper(accumulator: &mut FFI_Accumulator) { From 4d164cc77a6561efbd5ce0dc2be1e5688676cb92 Mon Sep 17 00:00:00 2001 From: Crystal Zhou Date: Tue, 8 Apr 2025 18:34:40 -0400 Subject: [PATCH 19/32] Revise comments and address PR comments --- datafusion/ffi/src/udaf/accumulator.rs | 31 +++++++++++-------------- datafusion/ffi/src/udaf/mod.rs | 26 ++++++++++++++------- datafusion/ffi/tests/ffi_integration.rs | 10 +++----- 3 files changed, 34 insertions(+), 33 deletions(-) diff --git a/datafusion/ffi/src/udaf/accumulator.rs b/datafusion/ffi/src/udaf/accumulator.rs index 897cd9f49cc3..9178c0656041 100644 --- a/datafusion/ffi/src/udaf/accumulator.rs +++ b/datafusion/ffi/src/udaf/accumulator.rs @@ -78,9 +78,9 @@ pub struct AccumulatorPrivateData { impl FFI_Accumulator { #[inline] - unsafe fn inner(&self) -> &mut AccumulatorPrivateData { + unsafe fn inner(&self) -> &mut Box { let private_data = self.private_data as *mut AccumulatorPrivateData; - &mut (*private_data) + &mut (*private_data).accumulator } } @@ -88,7 +88,7 @@ unsafe extern "C" fn update_batch_fn_wrapper( accumulator: &mut FFI_Accumulator, values: RVec, ) -> RResult<(), RString> { - let accum_data = accumulator.inner(); + let accumulator = accumulator.inner(); let values_arrays = values 
.into_iter() @@ -96,15 +96,15 @@ unsafe extern "C" fn update_batch_fn_wrapper( .collect::>>(); let values_arrays = rresult_return!(values_arrays); - rresult!(accum_data.accumulator.update_batch(&values_arrays)) + rresult!(accumulator.update_batch(&values_arrays)) } unsafe extern "C" fn evaluate_fn_wrapper( accumulator: &FFI_Accumulator, ) -> RResult, RString> { - let accum_data = accumulator.inner(); + let accumulator = accumulator.inner(); - let scalar_result = rresult_return!(accum_data.accumulator.evaluate()); + let scalar_result = rresult_return!(accumulator.evaluate()); let proto_result: datafusion_proto::protobuf::ScalarValue = rresult_return!((&scalar_result).try_into()); @@ -112,18 +112,15 @@ unsafe extern "C" fn evaluate_fn_wrapper( } unsafe extern "C" fn size_fn_wrapper(accumulator: &FFI_Accumulator) -> usize { - // let private_data = accumulator.private_data as *mut AccumulatorPrivateData; - // let accum_data = &mut (*private_data); - accumulator.inner().accumulator.size() + accumulator.inner().size() } unsafe extern "C" fn state_fn_wrapper( accumulator: &FFI_Accumulator, ) -> RResult>, RString> { - let private_data = accumulator.private_data as *mut AccumulatorPrivateData; - let accum_data = &mut (*private_data); + let accumulator = accumulator.inner(); - let state = rresult_return!(accum_data.accumulator.state()); + let state = rresult_return!(accumulator.state()); let state = state .into_iter() .map(|state_val| { @@ -141,21 +138,21 @@ unsafe extern "C" fn merge_batch_fn_wrapper( accumulator: &mut FFI_Accumulator, states: RVec, ) -> RResult<(), RString> { - let accum_data = accumulator.inner(); + let accumulator = accumulator.inner(); let states = rresult_return!(states .into_iter() .map(|state| ArrayRef::try_from(state).map_err(DataFusionError::from)) .collect::>>()); - rresult!(accum_data.accumulator.merge_batch(&states)) + rresult!(accumulator.merge_batch(&states)) } unsafe extern "C" fn retract_batch_fn_wrapper( accumulator: &mut FFI_Accumulator, 
values: RVec, ) -> RResult<(), RString> { - let accum_data = accumulator.inner(); + let accumulator = accumulator.inner(); let values_arrays = values .into_iter() @@ -163,7 +160,7 @@ unsafe extern "C" fn retract_batch_fn_wrapper( .collect::>>(); let values_arrays = rresult_return!(values_arrays); - rresult!(accum_data.accumulator.retract_batch(&values_arrays)) + rresult!(accumulator.retract_batch(&values_arrays)) } unsafe extern "C" fn release_fn_wrapper(accumulator: &mut FFI_Accumulator) { @@ -185,8 +182,6 @@ impl From> for FFI_Accumulator { merge_batch: merge_batch_fn_wrapper, retract_batch: retract_batch_fn_wrapper, supports_retract_batch, - - // clone: clone_fn_wrapper, release: release_fn_wrapper, private_data: Box::into_raw(Box::new(private_data)) as *mut c_void, } diff --git a/datafusion/ffi/src/udaf/mod.rs b/datafusion/ffi/src/udaf/mod.rs index ac59cef8fdf4..544081c557dd 100644 --- a/datafusion/ffi/src/udaf/mod.rs +++ b/datafusion/ffi/src/udaf/mod.rs @@ -61,33 +61,40 @@ pub struct FFI_AggregateUDF { /// FFI equivalent to the `name` of a [`AggregateUDF`] pub name: RString, - /// FFI equivalent to the `name` of a [`AggregateUDF`] + /// FFI equivalent to the `aliases` of a [`AggregateUDF`] pub aliases: RVec, - /// FFI equivalent to the `name` of a [`AggregateUDF`] + /// FFI equivalent to the `volatility` of a [`AggregateUDF`] pub volatility: FFI_Volatility, + /// Determines the return type of the underlying [`AggregateUDF`] based on the + /// argument types. 
pub return_type: unsafe extern "C" fn( udaf: &Self, arg_types: RVec, ) -> RResult, + /// FFI equivalent to the `is_nullable` of a [`AggregateUDF`] pub is_nullable: bool, + /// FFI equivalent to [`AggregateUDF::groups_accumulator_supported`] pub groups_accumulator_supported: unsafe extern "C" fn(udaf: &FFI_AggregateUDF, args: FFI_AccumulatorArgs) -> bool, + /// FFI equivalent to [`AggregateUDF::accumulator`] pub accumulator: unsafe extern "C" fn( udaf: &FFI_AggregateUDF, args: FFI_AccumulatorArgs, ) -> RResult, + /// FFI equivalent to [`AggregateUDF::create_sliding_accumulator`] pub create_sliding_accumulator: unsafe extern "C" fn( udaf: &FFI_AggregateUDF, args: FFI_AccumulatorArgs, ) -> RResult, + /// FFI equivalent to [`AggregateUDF::state_fields`] #[allow(clippy::type_complexity)] pub state_fields: unsafe extern "C" fn( udaf: &FFI_AggregateUDF, @@ -98,18 +105,21 @@ pub struct FFI_AggregateUDF { is_distinct: bool, ) -> RResult>, RString>, + /// FFI equivalent to [`AggregateUDF::create_groups_accumulator`] pub create_groups_accumulator: unsafe extern "C" fn( udaf: &FFI_AggregateUDF, args: FFI_AccumulatorArgs, ) -> RResult, + /// FFI equivalent to [`AggregateUDF::with_beneficial_ordering`] pub with_beneficial_ordering: unsafe extern "C" fn( udaf: &FFI_AggregateUDF, beneficial_ordering: bool, ) -> RResult, RString>, + /// FFI equivalent to [`AggregateUDF::order_sensitivity`] pub order_sensitivity: unsafe extern "C" fn(udaf: &FFI_AggregateUDF) -> FFI_AggregateOrderSensitivity, @@ -585,14 +595,14 @@ mod tests { fn test_round_trip_udaf() -> Result<()> { let original_udaf = Sum::new(); let original_name = original_udaf.name().to_owned(); + let original_udaf = Arc::new(AggregateUDF::from(original_udaf)); - let foreign_udaf = create_test_foreign_udaf(original_udaf)?; - // let original_udaf = Arc::new(AggregateUDF::from(original_udaf)); - - // let local_udaf: FFI_AggregateUDF = Arc::clone(&original_udaf).into(); + // Convert to FFI format + let local_udaf: FFI_AggregateUDF 
= Arc::clone(&original_udaf).into(); - // let foreign_udaf: ForeignAggregateUDF = (&local_udaf).try_into()?; - // let foreign_udaf: AggregateUDF = foreign_udaf.into(); + // Convert back to native format + let foreign_udaf: ForeignAggregateUDF = (&local_udaf).try_into()?; + let foreign_udaf: AggregateUDF = foreign_udaf.into(); assert_eq!(original_name, foreign_udaf.name()); Ok(()) diff --git a/datafusion/ffi/tests/ffi_integration.rs b/datafusion/ffi/tests/ffi_integration.rs index 3ea65caadf02..af122438dd63 100644 --- a/datafusion/ffi/tests/ffi_integration.rs +++ b/datafusion/ffi/tests/ffi_integration.rs @@ -19,8 +19,6 @@ /// when the feature integtation-tests is built #[cfg(feature = "integration-tests")] mod tests { - - use abi_stable::library::RootModule; use arrow::array::Float64Array; use datafusion::common::record_batch; use datafusion::error::{DataFusionError, Result}; @@ -31,8 +29,6 @@ mod tests { use datafusion_ffi::tests::utils::get_module; use datafusion::logical_expr::{AggregateUDF, ScalarUDF}; use datafusion_ffi::udaf::ForeignAggregateUDF; - use datafusion_ffi::udf::ForeignScalarUDF; - use std::path::Path; use std::sync::Arc; /// It is important that this test is in the `tests` directory and not in the @@ -108,15 +104,15 @@ mod tests { async fn test_ffi_udaf() -> Result<()> { let module = get_module()?; - let ffi_avg_func = + let ffi_sum_func = module .create_sum_udaf() .ok_or(DataFusionError::NotImplemented( "External table provider failed to implement create_udaf".to_string(), ))?(); - let foreign_avg_func: ForeignAggregateUDF = (&ffi_avg_func).try_into()?; + let foreign_sum_func: ForeignAggregateUDF = (&ffi_sum_func).try_into()?; - let udaf: AggregateUDF = foreign_avg_func.into(); + let udaf: AggregateUDF = foreign_sum_func.into(); let ctx = SessionContext::default(); let record_batch = record_batch!( From 173d7c09767c32a4c8e3e5696bd3b0ffdafd5e9c Mon Sep 17 00:00:00 2001 From: Crystal Zhou Date: Tue, 8 Apr 2025 18:37:52 -0400 Subject: [PATCH 20/32] 
Remove commented code --- datafusion/ffi/src/udaf/groups_accumulator.rs | 13 ------------- 1 file changed, 13 deletions(-) diff --git a/datafusion/ffi/src/udaf/groups_accumulator.rs b/datafusion/ffi/src/udaf/groups_accumulator.rs index 3f6b5def4f9b..fed8d9d14b1e 100644 --- a/datafusion/ffi/src/udaf/groups_accumulator.rs +++ b/datafusion/ffi/src/udaf/groups_accumulator.rs @@ -238,19 +238,6 @@ unsafe extern "C" fn release_fn_wrapper(accumulator: &mut FFI_GroupsAccumulator) drop(private_data); } -// unsafe extern "C" fn clone_fn_wrapper(accumulator: &FFI_GroupsAccumulator) -> FFI_GroupsAccumulator { -// let private_data = accumulator.private_data as *const GroupsAccumulatorPrivateData; -// let accum_data = &(*private_data); - -// Box::new(accum_data.accumulator).into() -// } - -// impl Clone for FFI_GroupsAccumulator { -// fn clone(&self) -> Self { -// unsafe { (self.clone)(self) } -// } -// } - impl From> for FFI_GroupsAccumulator { fn from(accumulator: Box) -> Self { let supports_convert_to_state = accumulator.supports_convert_to_state(); From 45ea2835b5428b6df6d78891ba1a2a642aab1b5c Mon Sep 17 00:00:00 2001 From: Crystal Zhou Date: Tue, 8 Apr 2025 18:40:26 -0400 Subject: [PATCH 21/32] Refactor GroupsAccumulator --- datafusion/ffi/src/udaf/groups_accumulator.rs | 42 +++++++++---------- 1 file changed, 20 insertions(+), 22 deletions(-) diff --git a/datafusion/ffi/src/udaf/groups_accumulator.rs b/datafusion/ffi/src/udaf/groups_accumulator.rs index fed8d9d14b1e..77878403741b 100644 --- a/datafusion/ffi/src/udaf/groups_accumulator.rs +++ b/datafusion/ffi/src/udaf/groups_accumulator.rs @@ -93,6 +93,13 @@ pub struct GroupsAccumulatorPrivateData { pub accumulator: Box, } +impl FFI_GroupsAccumulator { + unsafe fn inner(&self) -> &mut Box { + let private_data = self.private_data as *mut GroupsAccumulatorPrivateData; + &mut (*private_data).accumulator + } +} + unsafe extern "C" fn update_batch_fn_wrapper( accumulator: &mut FFI_GroupsAccumulator, values: RVec, @@ -100,8 +107,7 
@@ unsafe extern "C" fn update_batch_fn_wrapper( opt_filter: ROption, total_num_groups: usize, ) -> RResult<(), RString> { - let private_data = accumulator.private_data as *mut GroupsAccumulatorPrivateData; - let accum_data = &mut (*private_data); + let accumulator = accumulator.inner(); let values_arrays = values .into_iter() @@ -122,7 +128,7 @@ unsafe extern "C" fn update_batch_fn_wrapper( }).map(|arr| arr.into_data()); let opt_filter = maybe_filter.map(BooleanArray::from); - rresult!(accum_data.accumulator.update_batch( + rresult!(accumulator.update_batch( &values_arrays, &group_indices, opt_filter.as_ref(), @@ -134,29 +140,25 @@ unsafe extern "C" fn evaluate_fn_wrapper( accumulator: &FFI_GroupsAccumulator, emit_to: FFI_EmitTo, ) -> RResult { - let private_data = accumulator.private_data as *mut GroupsAccumulatorPrivateData; - let accum_data = &mut (*private_data); + let accumulator = accumulator.inner(); - let result = rresult_return!(accum_data.accumulator.evaluate(emit_to.into())); + let result = rresult_return!(accumulator.evaluate(emit_to.into())); rresult!(WrappedArray::try_from(&result)) } unsafe extern "C" fn size_fn_wrapper(accumulator: &FFI_GroupsAccumulator) -> usize { - let private_data = accumulator.private_data as *mut GroupsAccumulatorPrivateData; - let accum_data = &mut (*private_data); - - accum_data.accumulator.size() + let accumulator = accumulator.inner(); + accumulator.size() } unsafe extern "C" fn state_fn_wrapper( accumulator: &FFI_GroupsAccumulator, emit_to: FFI_EmitTo, ) -> RResult, RString> { - let private_data = accumulator.private_data as *mut GroupsAccumulatorPrivateData; - let accum_data = &mut (*private_data); + let accumulator = accumulator.inner(); - let state = rresult_return!(accum_data.accumulator.state(emit_to.into())); + let state = rresult_return!(accumulator.state(emit_to.into())); rresult!(state .into_iter() .map(|arr| WrappedArray::try_from(&arr).map_err(DataFusionError::from)) @@ -170,8 +172,7 @@ unsafe extern "C" fn 
merge_batch_fn_wrapper( opt_filter: ROption, total_num_groups: usize, ) -> RResult<(), RString> { - let private_data = accumulator.private_data as *mut GroupsAccumulatorPrivateData; - let accum_data = &mut (*private_data); + let accumulator = accumulator.inner(); let values_arrays = values .into_iter() .map(|v| v.try_into().map_err(DataFusionError::from)) @@ -191,7 +192,7 @@ unsafe extern "C" fn merge_batch_fn_wrapper( }).map(|arr| arr.into_data()); let opt_filter = maybe_filter.map(BooleanArray::from); - rresult!(accum_data.accumulator.merge_batch( + rresult!(accumulator.merge_batch( &values_arrays, &group_indices, opt_filter.as_ref(), @@ -204,9 +205,7 @@ unsafe extern "C" fn convert_to_state_fn_wrapper( values: RVec, opt_filter: ROption, ) -> RResult, RString> { - let private_data = accumulator.private_data as *mut GroupsAccumulatorPrivateData; - let accum_data = &mut (*private_data); - + let accumulator = accumulator.inner(); let values = rresult_return!(values .into_iter() .map(|v| ArrayRef::try_from(v).map_err(DataFusionError::from)) @@ -222,9 +221,8 @@ unsafe extern "C" fn convert_to_state_fn_wrapper( } }).map(|arr| arr.into_data()).map(BooleanArray::from); - let state = rresult_return!(accum_data - .accumulator - .convert_to_state(&values, opt_filter.as_ref())); + let state = + rresult_return!(accumulator.convert_to_state(&values, opt_filter.as_ref())); rresult!(state .iter() From b6da0e9d3138dbffaee53283e625aa216952752e Mon Sep 17 00:00:00 2001 From: Crystal Zhou Date: Fri, 11 Apr 2025 21:09:16 -0400 Subject: [PATCH 22/32] Add documentation --- datafusion/ffi/src/udaf/groups_accumulator.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/datafusion/ffi/src/udaf/groups_accumulator.rs b/datafusion/ffi/src/udaf/groups_accumulator.rs index 77878403741b..597566e4f51d 100644 --- a/datafusion/ffi/src/udaf/groups_accumulator.rs +++ b/datafusion/ffi/src/udaf/groups_accumulator.rs @@ -36,6 +36,7 @@ use crate::{ df_result, rresult, rresult_return, }; +/// A stable 
struct for sharing [`GroupsAccumulator`] across FFI boundaries. #[repr(C)] #[derive(Debug, StableAbi)] #[allow(non_camel_case_types)] From a2fbda8d7c4527747109f52b3a3c4fbec64b1db6 Mon Sep 17 00:00:00 2001 From: Crystal Zhou Date: Fri, 11 Apr 2025 23:01:53 -0400 Subject: [PATCH 23/32] Split integration tests --- datafusion/ffi/tests/ffi_integration.rs | 99 ------------------ datafusion/ffi/tests/ffi_udaf.rs | 129 ++++++++++++++++++++++++ 2 files changed, 129 insertions(+), 99 deletions(-) create mode 100644 datafusion/ffi/tests/ffi_udaf.rs diff --git a/datafusion/ffi/tests/ffi_integration.rs b/datafusion/ffi/tests/ffi_integration.rs index af122438dd63..3fad88254645 100644 --- a/datafusion/ffi/tests/ffi_integration.rs +++ b/datafusion/ffi/tests/ffi_integration.rs @@ -99,103 +99,4 @@ mod tests { Ok(()) } - - #[tokio::test] - async fn test_ffi_udaf() -> Result<()> { - let module = get_module()?; - - let ffi_sum_func = - module - .create_sum_udaf() - .ok_or(DataFusionError::NotImplemented( - "External table provider failed to implement create_udaf".to_string(), - ))?(); - let foreign_sum_func: ForeignAggregateUDF = (&ffi_sum_func).try_into()?; - - let udaf: AggregateUDF = foreign_sum_func.into(); - - let ctx = SessionContext::default(); - let record_batch = record_batch!( - ("a", Int32, vec![1, 2, 2, 4, 4, 4, 4]), - ("b", Float64, vec![1.0, 2.0, 2.0, 4.0, 4.0, 4.0, 4.0]) - ) - .unwrap(); - - let df = ctx.read_batch(record_batch)?; - - let df = df - .aggregate( - vec![col("a")], - vec![udaf.call(vec![col("b")]).alias("sum_b")], - )? 
- .sort_by(vec![col("a")])?; - - let result = df.collect().await?; - - let expected = record_batch!( - ("a", Int32, vec![1, 2, 4]), - ("sum_b", Float64, vec![1.0, 4.0, 16.0]) - )?; - - assert_eq!(result[0], expected); - - Ok(()) - } - - #[tokio::test] - async fn test_ffi_grouping_udaf() -> Result<()> { - let module = get_module()?; - - let ffi_stddev_func = - module - .create_stddev_udaf() - .ok_or(DataFusionError::NotImplemented( - "External table provider failed to implement create_udaf".to_string(), - ))?(); - let foreign_stddev_func: ForeignAggregateUDF = (&ffi_stddev_func).try_into()?; - - let udaf: AggregateUDF = foreign_stddev_func.into(); - - let ctx = SessionContext::default(); - let record_batch = record_batch!( - ("a", Int32, vec![1, 2, 2, 4, 4, 4, 4]), - ( - "b", - Float64, - vec![ - 1.0, - 2.0, - 2.0 + 2.0_f64.sqrt(), - 4.0, - 4.0, - 4.0 + 3.0_f64.sqrt(), - 4.0 + 3.0_f64.sqrt() - ] - ) - ) - .unwrap(); - - let df = ctx.read_batch(record_batch)?; - - let df = df - .aggregate( - vec![col("a")], - vec![udaf.call(vec![col("b")]).alias("stddev_b")], - )? - .sort_by(vec![col("a")])?; - - let result = df.collect().await?; - let result = result[0].column_by_name("stddev_b").unwrap(); - let result = result - .as_any() - .downcast_ref::() - .unwrap() - .values(); - - assert!(result.first().unwrap().is_nan()); - assert!(result.get(1).unwrap() - 1.0 < 0.00001); - assert!(result.get(2).unwrap() - 1.0 < 0.00001); - - Ok(()) - } } diff --git a/datafusion/ffi/tests/ffi_udaf.rs b/datafusion/ffi/tests/ffi_udaf.rs new file mode 100644 index 000000000000..31b1f473913c --- /dev/null +++ b/datafusion/ffi/tests/ffi_udaf.rs @@ -0,0 +1,129 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +/// Add an additional module here for convenience to scope this to only +/// when the feature integtation-tests is built +#[cfg(feature = "integration-tests")] +mod tests { + use arrow::array::Float64Array; + use datafusion::common::record_batch; + use datafusion::error::{DataFusionError, Result}; + use datafusion::logical_expr::AggregateUDF; + use datafusion::prelude::{col, SessionContext}; + + use datafusion_ffi::tests::utils::get_module; + use datafusion_ffi::udaf::ForeignAggregateUDF; + + #[tokio::test] + async fn test_ffi_udaf() -> Result<()> { + let module = get_module()?; + + let ffi_sum_func = + module + .create_sum_udaf() + .ok_or(DataFusionError::NotImplemented( + "External table provider failed to implement create_udaf".to_string(), + ))?(); + let foreign_sum_func: ForeignAggregateUDF = (&ffi_sum_func).try_into()?; + + let udaf: AggregateUDF = foreign_sum_func.into(); + + let ctx = SessionContext::default(); + let record_batch = record_batch!( + ("a", Int32, vec![1, 2, 2, 4, 4, 4, 4]), + ("b", Float64, vec![1.0, 2.0, 2.0, 4.0, 4.0, 4.0, 4.0]) + ) + .unwrap(); + + let df = ctx.read_batch(record_batch)?; + + let df = df + .aggregate( + vec![col("a")], + vec![udaf.call(vec![col("b")]).alias("sum_b")], + )? 
+ .sort_by(vec![col("a")])?; + + let result = df.collect().await?; + + let expected = record_batch!( + ("a", Int32, vec![1, 2, 4]), + ("sum_b", Float64, vec![1.0, 4.0, 16.0]) + )?; + + assert_eq!(result[0], expected); + + Ok(()) + } + + #[tokio::test] + async fn test_ffi_grouping_udaf() -> Result<()> { + let module = get_module()?; + + let ffi_stddev_func = + module + .create_stddev_udaf() + .ok_or(DataFusionError::NotImplemented( + "External table provider failed to implement create_udaf".to_string(), + ))?(); + let foreign_stddev_func: ForeignAggregateUDF = (&ffi_stddev_func).try_into()?; + + let udaf: AggregateUDF = foreign_stddev_func.into(); + + let ctx = SessionContext::default(); + let record_batch = record_batch!( + ("a", Int32, vec![1, 2, 2, 4, 4, 4, 4]), + ( + "b", + Float64, + vec![ + 1.0, + 2.0, + 2.0 + 2.0_f64.sqrt(), + 4.0, + 4.0, + 4.0 + 3.0_f64.sqrt(), + 4.0 + 3.0_f64.sqrt() + ] + ) + ) + .unwrap(); + + let df = ctx.read_batch(record_batch)?; + + let df = df + .aggregate( + vec![col("a")], + vec![udaf.call(vec![col("b")]).alias("stddev_b")], + )? 
+ .sort_by(vec![col("a")])?; + + let result = df.collect().await?; + let result = result[0].column_by_name("stddev_b").unwrap(); + let result = result + .as_any() + .downcast_ref::() + .unwrap() + .values(); + + assert!(result.first().unwrap().is_nan()); + assert!(result.get(1).unwrap() - 1.0 < 0.00001); + assert!(result.get(2).unwrap() - 1.0 < 0.00001); + + Ok(()) + } +} From 0b4a8f57c3a4bb88af5263ce910b5db5d6471a63 Mon Sep 17 00:00:00 2001 From: Crystal Zhou Date: Fri, 11 Apr 2025 23:58:55 -0400 Subject: [PATCH 24/32] Address comments to refactor error handling for opt filter --- datafusion/ffi/src/udaf/groups_accumulator.rs | 81 +++++++------------ 1 file changed, 28 insertions(+), 53 deletions(-) diff --git a/datafusion/ffi/src/udaf/groups_accumulator.rs b/datafusion/ffi/src/udaf/groups_accumulator.rs index 597566e4f51d..792c0b6f883f 100644 --- a/datafusion/ffi/src/udaf/groups_accumulator.rs +++ b/datafusion/ffi/src/udaf/groups_accumulator.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -use std::ffi::c_void; +use std::{ffi::c_void, sync::Arc}; use abi_stable::{ std_types::{ROption, RResult, RString, RVec}, @@ -101,6 +101,25 @@ impl FFI_GroupsAccumulator { } } +fn process_values(values: RVec) -> Result>> { + values + .into_iter() + .map(|v| v.try_into().map_err(DataFusionError::from)) + .collect::>>() +} + +/// Convert C-typed opt_filter into the internal type. 
+fn process_opt_filter(opt_filter: ROption) -> Result> { + opt_filter + .into_option() + .map(|filter| { + ArrayRef::try_from(filter) + .map_err(DataFusionError::from) + .map(|arr| BooleanArray::from(arr.into_data())) + }) + .transpose() +} + unsafe extern "C" fn update_batch_fn_wrapper( accumulator: &mut FFI_GroupsAccumulator, values: RVec, @@ -109,28 +128,12 @@ unsafe extern "C" fn update_batch_fn_wrapper( total_num_groups: usize, ) -> RResult<(), RString> { let accumulator = accumulator.inner(); - - let values_arrays = values - .into_iter() - .map(|v| v.try_into().map_err(DataFusionError::from)) - .collect::>>(); - let values_arrays = rresult_return!(values_arrays); - + let values = rresult_return!(process_values(values)); let group_indices: Vec = group_indices.into_iter().collect(); - - let maybe_filter = opt_filter.into_option().and_then(|filter| { - match ArrayRef::try_from(filter) { - Ok(v) => Some(v), - Err(e) => { - log::warn!("Error during FFI array conversion. Ignoring optional filter in groups accumulator. {}", e); - None - } - } - }).map(|arr| arr.into_data()); - let opt_filter = maybe_filter.map(BooleanArray::from); + let opt_filter = rresult_return!(process_opt_filter(opt_filter)); rresult!(accumulator.update_batch( - &values_arrays, + &values, &group_indices, opt_filter.as_ref(), total_num_groups @@ -174,27 +177,12 @@ unsafe extern "C" fn merge_batch_fn_wrapper( total_num_groups: usize, ) -> RResult<(), RString> { let accumulator = accumulator.inner(); - let values_arrays = values - .into_iter() - .map(|v| v.try_into().map_err(DataFusionError::from)) - .collect::>>(); - let values_arrays = rresult_return!(values_arrays); - + let values = rresult_return!(process_values(values)); let group_indices: Vec = group_indices.into_iter().collect(); - - let maybe_filter = opt_filter.into_option().and_then(|filter| { - match ArrayRef::try_from(filter) { - Ok(v) => Some(v), - Err(e) => { - log::warn!("Error during FFI array conversion. 
Ignoring optional filter in groups accumulator. {}", e); - None - } - } - }).map(|arr| arr.into_data()); - let opt_filter = maybe_filter.map(BooleanArray::from); + let opt_filter = rresult_return!(process_opt_filter(opt_filter)); rresult!(accumulator.merge_batch( - &values_arrays, + &values, &group_indices, opt_filter.as_ref(), total_num_groups @@ -207,21 +195,8 @@ unsafe extern "C" fn convert_to_state_fn_wrapper( opt_filter: ROption, ) -> RResult, RString> { let accumulator = accumulator.inner(); - let values = rresult_return!(values - .into_iter() - .map(|v| ArrayRef::try_from(v).map_err(DataFusionError::from)) - .collect::>>()); - - let opt_filter = opt_filter.into_option().and_then(|filter| { - match ArrayRef::try_from(filter) { - Ok(v) => Some(v), - Err(e) => { - log::warn!("Error during FFI array conversion. Ignoring optional filter in groups accumulator. {}", e); - None - } - } - }).map(|arr| arr.into_data()).map(BooleanArray::from); - + let values = rresult_return!(process_values(values)); + let opt_filter = rresult_return!(process_opt_filter(opt_filter)); let state = rresult_return!(accumulator.convert_to_state(&values, opt_filter.as_ref())); From 4b3c5335442cb3d025c559712be9c331409de06c Mon Sep 17 00:00:00 2001 From: Crystal Zhou Date: Sat, 12 Apr 2025 19:25:59 -0400 Subject: [PATCH 25/32] Fix linting errors --- datafusion/ffi/src/udaf/accumulator.rs | 24 ++++++++++----- datafusion/ffi/src/udaf/groups_accumulator.rs | 29 ++++++++++++------- datafusion/ffi/tests/ffi_integration.rs | 2 -- 3 files changed, 35 insertions(+), 20 deletions(-) diff --git a/datafusion/ffi/src/udaf/accumulator.rs b/datafusion/ffi/src/udaf/accumulator.rs index 9178c0656041..497aabe7632c 100644 --- a/datafusion/ffi/src/udaf/accumulator.rs +++ b/datafusion/ffi/src/udaf/accumulator.rs @@ -42,12 +42,13 @@ pub struct FFI_Accumulator { ) -> RResult<(), RString>, // Evaluate and return a ScalarValues as protobuf bytes - pub evaluate: unsafe extern "C" fn(accumulator: &Self) -> RResult, 
RString>, + pub evaluate: + unsafe extern "C" fn(accumulator: &mut Self) -> RResult, RString>, pub size: unsafe extern "C" fn(accumulator: &Self) -> usize, pub state: - unsafe extern "C" fn(accumulator: &Self) -> RResult>, RString>, + unsafe extern "C" fn(accumulator: &mut Self) -> RResult>, RString>, pub merge_batch: unsafe extern "C" fn( accumulator: &mut Self, @@ -78,10 +79,16 @@ pub struct AccumulatorPrivateData { impl FFI_Accumulator { #[inline] - unsafe fn inner(&self) -> &mut Box { + unsafe fn inner(&mut self) -> &mut Box { let private_data = self.private_data as *mut AccumulatorPrivateData; &mut (*private_data).accumulator } + + #[inline] + unsafe fn inner_ref(&self) -> &dyn Accumulator { + let private_data = self.private_data as *const AccumulatorPrivateData; + &*(*private_data).accumulator + } } unsafe extern "C" fn update_batch_fn_wrapper( @@ -100,7 +107,7 @@ unsafe extern "C" fn update_batch_fn_wrapper( } unsafe extern "C" fn evaluate_fn_wrapper( - accumulator: &FFI_Accumulator, + accumulator: &mut FFI_Accumulator, ) -> RResult, RString> { let accumulator = accumulator.inner(); @@ -112,11 +119,11 @@ unsafe extern "C" fn evaluate_fn_wrapper( } unsafe extern "C" fn size_fn_wrapper(accumulator: &FFI_Accumulator) -> usize { - accumulator.inner().size() + accumulator.inner_ref().size() } unsafe extern "C" fn state_fn_wrapper( - accumulator: &FFI_Accumulator, + accumulator: &mut FFI_Accumulator, ) -> RResult>, RString> { let accumulator = accumulator.inner(); @@ -231,7 +238,7 @@ impl Accumulator for ForeignAccumulator { fn evaluate(&mut self) -> Result { unsafe { let scalar_bytes = - df_result!((self.accumulator.evaluate)(&self.accumulator))?; + df_result!((self.accumulator.evaluate)(&mut self.accumulator))?; let proto_scalar = datafusion_proto::protobuf::ScalarValue::decode(scalar_bytes.as_ref()) @@ -247,7 +254,8 @@ impl Accumulator for ForeignAccumulator { fn state(&mut self) -> Result> { unsafe { - let state_protos = 
df_result!((self.accumulator.state)(&self.accumulator))?; + let state_protos = + df_result!((self.accumulator.state)(&mut self.accumulator))?; state_protos .into_iter() diff --git a/datafusion/ffi/src/udaf/groups_accumulator.rs b/datafusion/ffi/src/udaf/groups_accumulator.rs index 792c0b6f883f..a0c5cf49274e 100644 --- a/datafusion/ffi/src/udaf/groups_accumulator.rs +++ b/datafusion/ffi/src/udaf/groups_accumulator.rs @@ -51,14 +51,14 @@ pub struct FFI_GroupsAccumulator { // Evaluate and return a ScalarValues as protobuf bytes pub evaluate: unsafe extern "C" fn( - accumulator: &Self, + accumulator: &mut Self, emit_to: FFI_EmitTo, ) -> RResult, pub size: unsafe extern "C" fn(accumulator: &Self) -> usize, pub state: unsafe extern "C" fn( - accumulator: &Self, + accumulator: &mut Self, emit_to: FFI_EmitTo, ) -> RResult, RString>, @@ -95,10 +95,17 @@ pub struct GroupsAccumulatorPrivateData { } impl FFI_GroupsAccumulator { - unsafe fn inner(&self) -> &mut Box { + #[inline] + unsafe fn inner(&mut self) -> &mut Box { let private_data = self.private_data as *mut GroupsAccumulatorPrivateData; &mut (*private_data).accumulator } + + #[inline] + unsafe fn inner_ref(&self) -> &dyn GroupsAccumulator { + let private_data = self.private_data as *const GroupsAccumulatorPrivateData; + &*(*private_data).accumulator + } } fn process_values(values: RVec) -> Result>> { @@ -141,7 +148,7 @@ unsafe extern "C" fn update_batch_fn_wrapper( } unsafe extern "C" fn evaluate_fn_wrapper( - accumulator: &FFI_GroupsAccumulator, + accumulator: &mut FFI_GroupsAccumulator, emit_to: FFI_EmitTo, ) -> RResult { let accumulator = accumulator.inner(); @@ -152,12 +159,12 @@ unsafe extern "C" fn evaluate_fn_wrapper( } unsafe extern "C" fn size_fn_wrapper(accumulator: &FFI_GroupsAccumulator) -> usize { - let accumulator = accumulator.inner(); + let accumulator = accumulator.inner_ref(); accumulator.size() } unsafe extern "C" fn state_fn_wrapper( - accumulator: &FFI_GroupsAccumulator, + accumulator: &mut 
FFI_GroupsAccumulator, emit_to: FFI_EmitTo, ) -> RResult, RString> { let accumulator = accumulator.inner(); @@ -194,7 +201,7 @@ unsafe extern "C" fn convert_to_state_fn_wrapper( values: RVec, opt_filter: ROption, ) -> RResult, RString> { - let accumulator = accumulator.inner(); + let accumulator = accumulator.inner_ref(); let values = rresult_return!(process_values(values)); let opt_filter = rresult_return!(process_opt_filter(opt_filter)); let state = @@ -298,7 +305,7 @@ impl GroupsAccumulator for ForeignGroupsAccumulator { fn evaluate(&mut self, emit_to: EmitTo) -> Result { unsafe { let return_array = df_result!((self.accumulator.evaluate)( - &self.accumulator, + &mut self.accumulator, emit_to.into() ))?; @@ -308,8 +315,10 @@ impl GroupsAccumulator for ForeignGroupsAccumulator { fn state(&mut self, emit_to: EmitTo) -> Result> { unsafe { - let returned_arrays = - df_result!((self.accumulator.state)(&self.accumulator, emit_to.into()))?; + let returned_arrays = df_result!((self.accumulator.state)( + &mut self.accumulator, + emit_to.into() + ))?; returned_arrays .into_iter() diff --git a/datafusion/ffi/tests/ffi_integration.rs b/datafusion/ffi/tests/ffi_integration.rs index 3fad88254645..adbead25ba99 100644 --- a/datafusion/ffi/tests/ffi_integration.rs +++ b/datafusion/ffi/tests/ffi_integration.rs @@ -19,8 +19,6 @@ /// when the feature integtation-tests is built #[cfg(feature = "integration-tests")] mod tests { - use arrow::array::Float64Array; - use datafusion::common::record_batch; use datafusion::error::{DataFusionError, Result}; use datafusion::prelude::{col, SessionContext}; use datafusion_ffi::catalog_provider::ForeignCatalogProvider; From 1b85dd9e12aecf58538a7fa6cc0dc1cedfe4b45b Mon Sep 17 00:00:00 2001 From: Crystal Zhou Date: Sat, 12 Apr 2025 19:53:50 -0400 Subject: [PATCH 26/32] Fix linting and add deref --- datafusion/ffi/src/udaf/accumulator.rs | 20 +++++++++---------- datafusion/ffi/src/udaf/groups_accumulator.rs | 20 +++++++++---------- 2 files changed, 
20 insertions(+), 20 deletions(-) diff --git a/datafusion/ffi/src/udaf/accumulator.rs b/datafusion/ffi/src/udaf/accumulator.rs index 497aabe7632c..b548ebc48243 100644 --- a/datafusion/ffi/src/udaf/accumulator.rs +++ b/datafusion/ffi/src/udaf/accumulator.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -use std::ffi::c_void; +use std::{ffi::c_void, ops::Deref}; use abi_stable::{ std_types::{RResult, RString, RVec}, @@ -79,15 +79,15 @@ pub struct AccumulatorPrivateData { impl FFI_Accumulator { #[inline] - unsafe fn inner(&mut self) -> &mut Box { + unsafe fn inner_mut(&mut self) -> &mut Box { let private_data = self.private_data as *mut AccumulatorPrivateData; &mut (*private_data).accumulator } #[inline] - unsafe fn inner_ref(&self) -> &dyn Accumulator { + unsafe fn inner(&self) -> &dyn Accumulator { let private_data = self.private_data as *const AccumulatorPrivateData; - &*(*private_data).accumulator + (*private_data).accumulator.deref() } } @@ -95,7 +95,7 @@ unsafe extern "C" fn update_batch_fn_wrapper( accumulator: &mut FFI_Accumulator, values: RVec, ) -> RResult<(), RString> { - let accumulator = accumulator.inner(); + let accumulator = accumulator.inner_mut(); let values_arrays = values .into_iter() @@ -109,7 +109,7 @@ unsafe extern "C" fn update_batch_fn_wrapper( unsafe extern "C" fn evaluate_fn_wrapper( accumulator: &mut FFI_Accumulator, ) -> RResult, RString> { - let accumulator = accumulator.inner(); + let accumulator = accumulator.inner_mut(); let scalar_result = rresult_return!(accumulator.evaluate()); let proto_result: datafusion_proto::protobuf::ScalarValue = @@ -119,13 +119,13 @@ unsafe extern "C" fn evaluate_fn_wrapper( } unsafe extern "C" fn size_fn_wrapper(accumulator: &FFI_Accumulator) -> usize { - accumulator.inner_ref().size() + accumulator.inner().size() } unsafe extern "C" fn state_fn_wrapper( accumulator: &mut FFI_Accumulator, ) -> RResult>, RString> { - let accumulator = accumulator.inner(); 
+ let accumulator = accumulator.inner_mut(); let state = rresult_return!(accumulator.state()); let state = state @@ -145,7 +145,7 @@ unsafe extern "C" fn merge_batch_fn_wrapper( accumulator: &mut FFI_Accumulator, states: RVec, ) -> RResult<(), RString> { - let accumulator = accumulator.inner(); + let accumulator = accumulator.inner_mut(); let states = rresult_return!(states .into_iter() @@ -159,7 +159,7 @@ unsafe extern "C" fn retract_batch_fn_wrapper( accumulator: &mut FFI_Accumulator, values: RVec, ) -> RResult<(), RString> { - let accumulator = accumulator.inner(); + let accumulator = accumulator.inner_mut(); let values_arrays = values .into_iter() diff --git a/datafusion/ffi/src/udaf/groups_accumulator.rs b/datafusion/ffi/src/udaf/groups_accumulator.rs index a0c5cf49274e..09b16ff913b6 100644 --- a/datafusion/ffi/src/udaf/groups_accumulator.rs +++ b/datafusion/ffi/src/udaf/groups_accumulator.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. 
-use std::{ffi::c_void, sync::Arc}; +use std::{ffi::c_void, ops::Deref, sync::Arc}; use abi_stable::{ std_types::{ROption, RResult, RString, RVec}, @@ -96,15 +96,15 @@ pub struct GroupsAccumulatorPrivateData { impl FFI_GroupsAccumulator { #[inline] - unsafe fn inner(&mut self) -> &mut Box { + unsafe fn inner_mut(&mut self) -> &mut Box { let private_data = self.private_data as *mut GroupsAccumulatorPrivateData; &mut (*private_data).accumulator } #[inline] - unsafe fn inner_ref(&self) -> &dyn GroupsAccumulator { + unsafe fn inner(&self) -> &dyn GroupsAccumulator { let private_data = self.private_data as *const GroupsAccumulatorPrivateData; - &*(*private_data).accumulator + (*private_data).accumulator.deref() } } @@ -134,7 +134,7 @@ unsafe extern "C" fn update_batch_fn_wrapper( opt_filter: ROption, total_num_groups: usize, ) -> RResult<(), RString> { - let accumulator = accumulator.inner(); + let accumulator = accumulator.inner_mut(); let values = rresult_return!(process_values(values)); let group_indices: Vec = group_indices.into_iter().collect(); let opt_filter = rresult_return!(process_opt_filter(opt_filter)); @@ -151,7 +151,7 @@ unsafe extern "C" fn evaluate_fn_wrapper( accumulator: &mut FFI_GroupsAccumulator, emit_to: FFI_EmitTo, ) -> RResult { - let accumulator = accumulator.inner(); + let accumulator = accumulator.inner_mut(); let result = rresult_return!(accumulator.evaluate(emit_to.into())); @@ -159,7 +159,7 @@ unsafe extern "C" fn evaluate_fn_wrapper( } unsafe extern "C" fn size_fn_wrapper(accumulator: &FFI_GroupsAccumulator) -> usize { - let accumulator = accumulator.inner_ref(); + let accumulator = accumulator.inner(); accumulator.size() } @@ -167,7 +167,7 @@ unsafe extern "C" fn state_fn_wrapper( accumulator: &mut FFI_GroupsAccumulator, emit_to: FFI_EmitTo, ) -> RResult, RString> { - let accumulator = accumulator.inner(); + let accumulator = accumulator.inner_mut(); let state = rresult_return!(accumulator.state(emit_to.into())); rresult!(state @@ -183,7 
+183,7 @@ unsafe extern "C" fn merge_batch_fn_wrapper( opt_filter: ROption, total_num_groups: usize, ) -> RResult<(), RString> { - let accumulator = accumulator.inner(); + let accumulator = accumulator.inner_mut(); let values = rresult_return!(process_values(values)); let group_indices: Vec = group_indices.into_iter().collect(); let opt_filter = rresult_return!(process_opt_filter(opt_filter)); @@ -201,7 +201,7 @@ unsafe extern "C" fn convert_to_state_fn_wrapper( values: RVec, opt_filter: ROption, ) -> RResult, RString> { - let accumulator = accumulator.inner_ref(); + let accumulator = accumulator.inner(); let values = rresult_return!(process_values(values)); let opt_filter = rresult_return!(process_opt_filter(opt_filter)); let state = From 8a4de4a9971117e993037340ddd068445f41e93a Mon Sep 17 00:00:00 2001 From: Crystal Zhou Date: Tue, 3 Jun 2025 17:49:14 -0400 Subject: [PATCH 27/32] Remove extra tests and unnecessary code --- .../src/datasource/file_format/parquet.rs | 107 ------------------ datafusion/core/src/lib.rs | 1 - 2 files changed, 108 deletions(-) diff --git a/datafusion/core/src/datasource/file_format/parquet.rs b/datafusion/core/src/datasource/file_format/parquet.rs index ab1e938336d4..6a5c19829c1c 100644 --- a/datafusion/core/src/datasource/file_format/parquet.rs +++ b/datafusion/core/src/datasource/file_format/parquet.rs @@ -1662,111 +1662,4 @@ mod tests { Ok(()) } - - #[tokio::test] - async fn test_memory_reservation_column_parallel() -> Result<()> { - async fn test_memory_reservation(global: ParquetOptions) -> Result<()> { - let field_a = Field::new("a", DataType::Utf8, false); - let field_b = Field::new("b", DataType::Utf8, false); - let schema = Arc::new(Schema::new(vec![field_a, field_b])); - let object_store_url = ObjectStoreUrl::local_filesystem(); - - let file_sink_config = FileSinkConfig { - original_url: String::default(), - object_store_url: object_store_url.clone(), - file_group: FileGroup::new(vec![PartitionedFile::new( - 
"/tmp".to_string(), - 1, - )]), - table_paths: vec![ListingTableUrl::parse("file:///")?], - output_schema: schema.clone(), - table_partition_cols: vec![], - insert_op: InsertOp::Overwrite, - keep_partition_by_columns: false, - file_extension: "parquet".into(), - }; - let parquet_sink = Arc::new(ParquetSink::new( - file_sink_config, - TableParquetOptions { - key_value_metadata: std::collections::HashMap::from([ - ("my-data".to_string(), Some("stuff".to_string())), - ("my-data-bool-key".to_string(), None), - ]), - global, - ..Default::default() - }, - )); - - // create data - let col_a: ArrayRef = Arc::new(StringArray::from(vec!["foo", "bar"])); - let col_b: ArrayRef = Arc::new(StringArray::from(vec!["baz", "baz"])); - let batch = - RecordBatch::try_from_iter(vec![("a", col_a), ("b", col_b)]).unwrap(); - - // create task context - let task_context = build_ctx(object_store_url.as_ref()); - assert_eq!( - task_context.memory_pool().reserved(), - 0, - "no bytes are reserved yet" - ); - - let mut write_task = FileSink::write_all( - parquet_sink.as_ref(), - Box::pin(RecordBatchStreamAdapter::new( - schema, - bounded_stream(batch, 1000), - )), - &task_context, - ); - - // incrementally poll and check for memory reservation - let mut reserved_bytes = 0; - while futures::poll!(&mut write_task).is_pending() { - reserved_bytes += task_context.memory_pool().reserved(); - tokio::time::sleep(Duration::from_micros(1)).await; - } - assert!( - reserved_bytes > 0, - "should have bytes reserved during write" - ); - assert_eq!( - task_context.memory_pool().reserved(), - 0, - "no leaking byte reservation" - ); - - Ok(()) - } - - let write_opts = ParquetOptions { - allow_single_file_parallelism: false, - ..Default::default() - }; - test_memory_reservation(write_opts) - .await - .expect("should track for non-parallel writes"); - - let row_parallel_write_opts = ParquetOptions { - allow_single_file_parallelism: true, - maximum_parallel_row_group_writers: 10, - 
maximum_buffered_record_batches_per_stream: 1, - ..Default::default() - }; - test_memory_reservation(row_parallel_write_opts) - .await - .expect("should track for row-parallel writes"); - - let col_parallel_write_opts = ParquetOptions { - allow_single_file_parallelism: true, - maximum_parallel_row_group_writers: 1, - maximum_buffered_record_batches_per_stream: 2, - ..Default::default() - }; - test_memory_reservation(col_parallel_write_opts) - .await - .expect("should track for column-parallel writes"); - - Ok(()) - } } diff --git a/datafusion/core/src/lib.rs b/datafusion/core/src/lib.rs index bce08231ba62..6956108e2df3 100644 --- a/datafusion/core/src/lib.rs +++ b/datafusion/core/src/lib.rs @@ -843,7 +843,6 @@ pub mod test; mod schema_equivalence; pub mod test_util; -// pub use test_util::bounded_stream; #[cfg(doctest)] doc_comment::doctest!("../../../README.md", readme_example_test); From 5e02c72272359f0b2f3141d951d9c6e185f222e5 Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Wed, 4 Jun 2025 15:13:47 -0400 Subject: [PATCH 28/32] Adjustments to FFI aggregate functions after rebase on main --- datafusion/ffi/src/udaf/accumulator_args.rs | 17 +++---- datafusion/ffi/src/udaf/mod.rs | 53 ++++++++++++--------- 2 files changed, 40 insertions(+), 30 deletions(-) diff --git a/datafusion/ffi/src/udaf/accumulator_args.rs b/datafusion/ffi/src/udaf/accumulator_args.rs index 3a25d09c4a55..1eda0a4c0274 100644 --- a/datafusion/ffi/src/udaf/accumulator_args.rs +++ b/datafusion/ffi/src/udaf/accumulator_args.rs @@ -25,6 +25,7 @@ use arrow::{ datatypes::{DataType, Schema}, ffi::FFI_ArrowSchema, }; +use arrow_schema::FieldRef; use datafusion::{ error::DataFusionError, logical_expr::function::AccumulatorArgs, physical_expr::LexOrdering, physical_plan::PhysicalExpr, prelude::SessionContext, @@ -45,7 +46,7 @@ use crate::arrow_wrappers::WrappedSchema; #[derive(Debug, StableAbi)] #[allow(non_camel_case_types)] pub struct FFI_AccumulatorArgs { - return_type: WrappedSchema, + return_field: 
WrappedSchema, schema: WrappedSchema, is_reversed: bool, name: RString, @@ -56,7 +57,7 @@ impl TryFrom> for FFI_AccumulatorArgs { type Error = DataFusionError; fn try_from(args: AccumulatorArgs) -> Result { - let return_type = WrappedSchema(FFI_ArrowSchema::try_from(args.return_type)?); + let return_field = WrappedSchema(FFI_ArrowSchema::try_from(args.return_field.as_ref())?); let schema = WrappedSchema(FFI_ArrowSchema::try_from(args.schema)?); let codec = DefaultPhysicalExtensionCodec {}; @@ -76,7 +77,7 @@ impl TryFrom> for FFI_AccumulatorArgs { let physical_expr_def = physical_expr_def.encode_to_vec().into(); Ok(Self { - return_type, + return_field, schema, is_reversed: args.is_reversed, name: args.name.into(), @@ -90,7 +91,7 @@ impl TryFrom> for FFI_AccumulatorArgs { /// data across the FFI boundary and turn it into owned data that /// AccumulatorArgs can then reference. pub struct ForeignAccumulatorArgs { - pub return_type: DataType, + pub return_field: FieldRef, pub schema: Schema, pub ignore_nulls: bool, pub ordering_req: LexOrdering, @@ -108,7 +109,7 @@ impl TryFrom for ForeignAccumulatorArgs { PhysicalAggregateExprNode::decode(value.physical_expr_def.as_ref()) .map_err(|e| DataFusionError::Execution(e.to_string()))?; - let return_type = (&value.return_type.0).try_into()?; + let return_field = Arc::new((&value.return_field.0).try_into()?); let schema = Schema::try_from(&value.schema.0)?; let default_ctx = SessionContext::new(); @@ -126,7 +127,7 @@ impl TryFrom for ForeignAccumulatorArgs { let exprs = parse_physical_exprs(&proto_def.expr, &default_ctx, &schema, &codex)?; Ok(Self { - return_type, + return_field, schema, ignore_nulls: proto_def.ignore_nulls, ordering_req, @@ -141,7 +142,7 @@ impl TryFrom for ForeignAccumulatorArgs { impl<'a> From<&'a ForeignAccumulatorArgs> for AccumulatorArgs<'a> { fn from(value: &'a ForeignAccumulatorArgs) -> Self { Self { - return_type: &value.return_type, + return_field: Arc::clone(&value.return_field), schema: 
&value.schema, ignore_nulls: value.ignore_nulls, ordering_req: &value.ordering_req, @@ -168,7 +169,7 @@ mod tests { fn test_round_trip_accumulator_args() -> Result<()> { let schema = Schema::new(vec![Field::new("a", DataType::Int32, true)]); let orig_args = AccumulatorArgs { - return_type: &DataType::Float64, + return_field: Field::new("f", DataType::Float64, true).into(), schema: &schema, ignore_nulls: false, ordering_req: &LexOrdering::new(vec![PhysicalSortExpr { diff --git a/datafusion/ffi/src/udaf/mod.rs b/datafusion/ffi/src/udaf/mod.rs index 544081c557dd..a35a1a4b69b3 100644 --- a/datafusion/ffi/src/udaf/mod.rs +++ b/datafusion/ffi/src/udaf/mod.rs @@ -25,11 +25,12 @@ use accumulator::{FFI_Accumulator, ForeignAccumulator}; use accumulator_args::{FFI_AccumulatorArgs, ForeignAccumulatorArgs}; use arrow::datatypes::{DataType, Field}; use arrow::ffi::FFI_ArrowSchema; +use arrow_schema::FieldRef; use datafusion::{ error::DataFusionError, logical_expr::{ function::{AccumulatorArgs, AggregateFunctionSimplification, StateFieldsArgs}, - type_coercion::functions::data_types_with_aggregate_udf, + type_coercion::functions::fields_with_aggregate_udf, utils::AggregateOrderSensitivity, Accumulator, GroupsAccumulator, }, @@ -48,6 +49,7 @@ use crate::{ volatility::FFI_Volatility, }; use prost::{DecodeError, Message}; +use crate::util::{rvec_wrapped_to_vec_fieldref, vec_fieldref_to_rvec_wrapped}; mod accumulator; mod accumulator_args; @@ -99,8 +101,8 @@ pub struct FFI_AggregateUDF { pub state_fields: unsafe extern "C" fn( udaf: &FFI_AggregateUDF, name: &RStr, - input_types: RVec, - return_type: WrappedSchema, + input_fields: RVec, + return_field: WrappedSchema, ordering_fields: RVec>, is_distinct: bool, ) -> RResult>, RString>, @@ -246,27 +248,30 @@ unsafe extern "C" fn with_beneficial_ordering_fn_wrapper( unsafe extern "C" fn state_fields_fn_wrapper( udaf: &FFI_AggregateUDF, name: &RStr, - input_types: RVec, - return_type: WrappedSchema, + input_fields: RVec, + return_field: 
WrappedSchema, ordering_fields: RVec>, is_distinct: bool, ) -> RResult>, RString> { let udaf = udaf.inner(); - let input_types = &rresult_return!(rvec_wrapped_to_vec_datatype(&input_types)); - let return_type = &rresult_return!(DataType::try_from(&return_type.0)); + let input_fields = &rresult_return!(rvec_wrapped_to_vec_fieldref(&input_fields)); + let return_field = rresult_return!(Field::try_from(&return_field.0)).into(); let ordering_fields = &rresult_return!(ordering_fields .into_iter() .map(|field_bytes| datafusion_proto_common::Field::decode(field_bytes.as_ref())) .collect::, DecodeError>>()); - let ordering_fields = &rresult_return!(parse_proto_fields_to_fields(ordering_fields)); + let ordering_fields = &rresult_return!(parse_proto_fields_to_fields(ordering_fields)) + .into_iter() + .map(Arc::new) + .collect::>(); let args = StateFieldsArgs { name: name.as_str(), - input_types, - return_type, + input_fields, + return_field, ordering_fields, is_distinct, }; @@ -274,6 +279,7 @@ unsafe extern "C" fn state_fields_fn_wrapper( let state_fields = rresult_return!(udaf.state_fields(args)); let state_fields = rresult_return!(state_fields .iter() + .map(|f| f.as_ref()) .map(datafusion_proto::protobuf::Field::try_from) .map(|v| v.map_err(DataFusionError::from)) .collect::>>()) @@ -298,7 +304,8 @@ unsafe extern "C" fn coerce_types_fn_wrapper( let arg_types = rresult_return!(rvec_wrapped_to_vec_datatype(&arg_types)); - let return_types = rresult_return!(data_types_with_aggregate_udf(&arg_types, udaf)); + let arg_fields = arg_types.iter().map(|dt| Field::new("f", dt.clone(), true)).map(Arc::new).collect::>(); + let return_types = rresult_return!(fields_with_aggregate_udf(&arg_fields, udaf)).into_iter().map(|f| f.data_type().to_owned()).collect::>(); rresult!(vec_datatype_to_rvec_wrapped(&return_types)) } @@ -421,14 +428,15 @@ impl AggregateUDFImpl for ForeignAggregateUDF { } } - fn state_fields(&self, args: StateFieldsArgs) -> Result> { + fn state_fields(&self, args: 
StateFieldsArgs) -> Result> { unsafe { let name = RStr::from_str(args.name); - let input_types = vec_datatype_to_rvec_wrapped(args.input_types)?; - let return_type = WrappedSchema(FFI_ArrowSchema::try_from(args.return_type)?); + let input_fields = vec_fieldref_to_rvec_wrapped(args.input_fields)?; + let return_field = WrappedSchema(FFI_ArrowSchema::try_from(args.return_field.as_ref())?); let ordering_fields = args .ordering_fields .iter() + .map(|f| f.as_ref()) .map(datafusion_proto::protobuf::Field::try_from) .map(|v| v.map_err(DataFusionError::from)) .collect::>>()? @@ -439,8 +447,8 @@ impl AggregateUDFImpl for ForeignAggregateUDF { let fields = df_result!((self.udaf.state_fields)( &self.udaf, &name, - input_types, - return_type, + input_fields, + return_field, ordering_fields, args.is_distinct ))?; @@ -453,6 +461,7 @@ impl AggregateUDFImpl for ForeignAggregateUDF { .collect::>>()?; parse_proto_fields_to_fields(fields.iter()) + .map(|fields| fields.into_iter().map(Arc::new).collect()) .map_err(|e| DataFusionError::Execution(e.to_string())) } } @@ -624,7 +633,7 @@ mod tests { let schema = Schema::new(vec![Field::new("a", DataType::Float64, true)]); let acc_args = AccumulatorArgs { - return_type: &DataType::Float64, + return_field: Field::new("f", DataType::Float64, true).into(), schema: &schema, ignore_nulls: true, ordering_req: &LexOrdering::new(vec![PhysicalSortExpr { @@ -658,12 +667,12 @@ mod tests { AggregateOrderSensitivity::Beneficial ); - let a_field = Field::new("a", DataType::Float64, true); + let a_field = Arc::new(Field::new("a", DataType::Float64, true)); let state_fields = foreign_udaf.state_fields(StateFieldsArgs { name: "a", - input_types: &[DataType::Float64], - return_type: &DataType::Float64, - ordering_fields: &[a_field.clone()], + input_fields: &[Field::new("f", DataType::Float64, true).into()], + return_field: Field::new("f", DataType::Float64, true).into(), + ordering_fields: &[Arc::clone(&a_field)], is_distinct: false, })?; @@ -679,7 +688,7 
@@ mod tests { let schema = Schema::new(vec![Field::new("a", DataType::Float64, true)]); let acc_args = AccumulatorArgs { - return_type: &DataType::Float64, + return_field: Field::new("f", DataType::Float64, true).into(), schema: &schema, ignore_nulls: true, ordering_req: &LexOrdering::new(vec![PhysicalSortExpr { From d128b85cc65a027d815e7afac60c1ee8873f4c70 Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Wed, 4 Jun 2025 15:14:01 -0400 Subject: [PATCH 29/32] cargo fmt --- datafusion/ffi/src/tests/mod.rs | 5 ++++- datafusion/ffi/src/tests/udf_udaf_udwf.rs | 1 - datafusion/ffi/src/udaf/accumulator_args.rs | 3 ++- datafusion/ffi/src/udaf/mod.rs | 16 ++++++++++++---- datafusion/ffi/tests/ffi_integration.rs | 4 ++-- 5 files changed, 20 insertions(+), 9 deletions(-) diff --git a/datafusion/ffi/src/tests/mod.rs b/datafusion/ffi/src/tests/mod.rs index 7a854fd1c33e..f65ed7441b42 100644 --- a/datafusion/ffi/src/tests/mod.rs +++ b/datafusion/ffi/src/tests/mod.rs @@ -39,7 +39,10 @@ use datafusion::{ common::record_batch, }; use sync_provider::create_sync_table_provider; -use udf_udaf_udwf::{create_ffi_abs_func, create_ffi_random_func, create_ffi_stddev_func, create_ffi_sum_func, create_ffi_table_func}; +use udf_udaf_udwf::{ + create_ffi_abs_func, create_ffi_random_func, create_ffi_stddev_func, + create_ffi_sum_func, create_ffi_table_func, +}; mod async_provider; pub mod catalog; diff --git a/datafusion/ffi/src/tests/udf_udaf_udwf.rs b/datafusion/ffi/src/tests/udf_udaf_udwf.rs index b2a4194fa355..6aa69bdd0c4a 100644 --- a/datafusion/ffi/src/tests/udf_udaf_udwf.rs +++ b/datafusion/ffi/src/tests/udf_udaf_udwf.rs @@ -44,7 +44,6 @@ pub(crate) extern "C" fn create_ffi_table_func() -> FFI_TableFunction { FFI_TableFunction::new(udtf, None) } - pub(crate) extern "C" fn create_ffi_sum_func() -> FFI_AggregateUDF { let udaf: Arc = Arc::new(Sum::new().into()); diff --git a/datafusion/ffi/src/udaf/accumulator_args.rs b/datafusion/ffi/src/udaf/accumulator_args.rs index 
1eda0a4c0274..a16490ae9ebd 100644 --- a/datafusion/ffi/src/udaf/accumulator_args.rs +++ b/datafusion/ffi/src/udaf/accumulator_args.rs @@ -57,7 +57,8 @@ impl TryFrom> for FFI_AccumulatorArgs { type Error = DataFusionError; fn try_from(args: AccumulatorArgs) -> Result { - let return_field = WrappedSchema(FFI_ArrowSchema::try_from(args.return_field.as_ref())?); + let return_field = + WrappedSchema(FFI_ArrowSchema::try_from(args.return_field.as_ref())?); let schema = WrappedSchema(FFI_ArrowSchema::try_from(args.schema)?); let codec = DefaultPhysicalExtensionCodec {}; diff --git a/datafusion/ffi/src/udaf/mod.rs b/datafusion/ffi/src/udaf/mod.rs index a35a1a4b69b3..449c9d9a2a5b 100644 --- a/datafusion/ffi/src/udaf/mod.rs +++ b/datafusion/ffi/src/udaf/mod.rs @@ -42,6 +42,7 @@ use datafusion::{ use datafusion_proto_common::from_proto::parse_proto_fields_to_fields; use groups_accumulator::{FFI_GroupsAccumulator, ForeignGroupsAccumulator}; +use crate::util::{rvec_wrapped_to_vec_fieldref, vec_fieldref_to_rvec_wrapped}; use crate::{ arrow_wrappers::WrappedSchema, df_result, rresult, rresult_return, @@ -49,7 +50,6 @@ use crate::{ volatility::FFI_Volatility, }; use prost::{DecodeError, Message}; -use crate::util::{rvec_wrapped_to_vec_fieldref, vec_fieldref_to_rvec_wrapped}; mod accumulator; mod accumulator_args; @@ -304,8 +304,15 @@ unsafe extern "C" fn coerce_types_fn_wrapper( let arg_types = rresult_return!(rvec_wrapped_to_vec_datatype(&arg_types)); - let arg_fields = arg_types.iter().map(|dt| Field::new("f", dt.clone(), true)).map(Arc::new).collect::>(); - let return_types = rresult_return!(fields_with_aggregate_udf(&arg_fields, udaf)).into_iter().map(|f| f.data_type().to_owned()).collect::>(); + let arg_fields = arg_types + .iter() + .map(|dt| Field::new("f", dt.clone(), true)) + .map(Arc::new) + .collect::>(); + let return_types = rresult_return!(fields_with_aggregate_udf(&arg_fields, udaf)) + .into_iter() + .map(|f| f.data_type().to_owned()) + .collect::>(); 
rresult!(vec_datatype_to_rvec_wrapped(&return_types)) } @@ -432,7 +439,8 @@ impl AggregateUDFImpl for ForeignAggregateUDF { unsafe { let name = RStr::from_str(args.name); let input_fields = vec_fieldref_to_rvec_wrapped(args.input_fields)?; - let return_field = WrappedSchema(FFI_ArrowSchema::try_from(args.return_field.as_ref())?); + let return_field = + WrappedSchema(FFI_ArrowSchema::try_from(args.return_field.as_ref())?); let ordering_fields = args .ordering_fields .iter() diff --git a/datafusion/ffi/tests/ffi_integration.rs b/datafusion/ffi/tests/ffi_integration.rs index adbead25ba99..bf1910c31279 100644 --- a/datafusion/ffi/tests/ffi_integration.rs +++ b/datafusion/ffi/tests/ffi_integration.rs @@ -20,12 +20,12 @@ #[cfg(feature = "integration-tests")] mod tests { use datafusion::error::{DataFusionError, Result}; + use datafusion::logical_expr::{AggregateUDF, ScalarUDF}; use datafusion::prelude::{col, SessionContext}; use datafusion_ffi::catalog_provider::ForeignCatalogProvider; use datafusion_ffi::table_provider::ForeignTableProvider; - use datafusion_ffi::tests::{create_record_batch, ForeignLibraryModuleRef}; use datafusion_ffi::tests::utils::get_module; - use datafusion::logical_expr::{AggregateUDF, ScalarUDF}; + use datafusion_ffi::tests::{create_record_batch, ForeignLibraryModuleRef}; use datafusion_ffi::udaf::ForeignAggregateUDF; use std::sync::Arc; From 4282c2ad49c662438e2e3c7329c9b98145eeef84 Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Wed, 4 Jun 2025 15:20:46 -0400 Subject: [PATCH 30/32] cargo clippy --- datafusion/ffi/src/arrow_wrappers.rs | 19 ------------------- datafusion/ffi/src/udaf/accumulator_args.rs | 10 +++------- datafusion/ffi/src/udaf/mod.rs | 5 ++--- datafusion/ffi/tests/ffi_integration.rs | 6 ++---- 4 files changed, 7 insertions(+), 33 deletions(-) diff --git a/datafusion/ffi/src/arrow_wrappers.rs b/datafusion/ffi/src/arrow_wrappers.rs index 64bedb9cfe67..9d2d0cd25034 100644 --- a/datafusion/ffi/src/arrow_wrappers.rs +++ 
b/datafusion/ffi/src/arrow_wrappers.rs @@ -32,16 +32,6 @@ use log::error; #[derive(Debug, StableAbi)] pub struct WrappedSchema(#[sabi(unsafe_opaque_field)] pub FFI_ArrowSchema); -/// Some functions are expected to always succeed, like getting the schema from a TableProvider. -/// Since going through the FFI always has the potential to fail, we need to catch these errors, -/// give the user a warning, and return some kind of result. In this case we default to an -/// empty schema. -#[cfg(not(tarpaulin_include))] -fn catch_ffi_schema_error(e: ArrowError) -> FFI_ArrowSchema { - error!("Unable to convert DataFusion Schema to FFI_ArrowSchema in FFI_PlanProperties. {}", e); - FFI_ArrowSchema::empty() -} - impl From for WrappedSchema { fn from(value: SchemaRef) -> Self { let ffi_schema = match FFI_ArrowSchema::try_from(value.as_ref()) { @@ -55,15 +45,6 @@ impl From for WrappedSchema { WrappedSchema(ffi_schema) } } -/// Some functions are expected to always succeed, like getting the schema from a TableProvider. -/// Since going through the FFI always has the potential to fail, we need to catch these errors, -/// give the user a warning, and return some kind of result. In this case we default to an -/// empty schema. -#[cfg(not(tarpaulin_include))] -fn catch_df_schema_error(e: ArrowError) -> Schema { - error!("Unable to convert from FFI_ArrowSchema to DataFusion Schema in FFI_PlanProperties. 
{}", e); - Schema::empty() -} impl From for SchemaRef { fn from(value: WrappedSchema) -> Self { diff --git a/datafusion/ffi/src/udaf/accumulator_args.rs b/datafusion/ffi/src/udaf/accumulator_args.rs index a16490ae9ebd..e8a4afff4184 100644 --- a/datafusion/ffi/src/udaf/accumulator_args.rs +++ b/datafusion/ffi/src/udaf/accumulator_args.rs @@ -21,10 +21,7 @@ use abi_stable::{ std_types::{RString, RVec}, StableAbi, }; -use arrow::{ - datatypes::{DataType, Schema}, - ffi::FFI_ArrowSchema, -}; +use arrow::{datatypes::Schema, ffi::FFI_ArrowSchema}; use arrow_schema::FieldRef; use datafusion::{ error::DataFusionError, logical_expr::function::AccumulatorArgs, @@ -182,18 +179,17 @@ mod tests { is_distinct: true, exprs: &[col("a", &schema)?], }; - let orig_str = format!("{:?}", orig_args); + let orig_str = format!("{orig_args:?}"); let ffi_args: FFI_AccumulatorArgs = orig_args.try_into()?; let foreign_args: ForeignAccumulatorArgs = ffi_args.try_into()?; let round_trip_args: AccumulatorArgs = (&foreign_args).into(); - let round_trip_str = format!("{:?}", round_trip_args); + let round_trip_str = format!("{round_trip_args:?}"); // Since AccumulatorArgs doesn't implement Eq, simply compare // the debug strings. assert_eq!(orig_str, round_trip_str); - println!("{}", round_trip_str); Ok(()) } diff --git a/datafusion/ffi/src/udaf/mod.rs b/datafusion/ffi/src/udaf/mod.rs index 449c9d9a2a5b..2529ed7a06dc 100644 --- a/datafusion/ffi/src/udaf/mod.rs +++ b/datafusion/ffi/src/udaf/mod.rs @@ -224,7 +224,7 @@ unsafe extern "C" fn groups_accumulator_supported_fn_wrapper( ForeignAccumulatorArgs::try_from(args) .map(|a| udaf.groups_accumulator_supported((&a).into())) .unwrap_or_else(|e| { - log::warn!("Unable to parse accumulator args. {}", e); + log::warn!("Unable to parse accumulator args. 
{e}"); false }) } @@ -478,7 +478,7 @@ impl AggregateUDFImpl for ForeignAggregateUDF { let args = match FFI_AccumulatorArgs::try_from(args) { Ok(v) => v, Err(e) => { - log::warn!("Attempting to convert accumulator arguments: {}", e); + log::warn!("Attempting to convert accumulator arguments: {e}"); return false; } }; @@ -684,7 +684,6 @@ mod tests { is_distinct: false, })?; - println!("{:#?}", state_fields); assert_eq!(state_fields.len(), 3); assert_eq!(state_fields[1], a_field); Ok(()) diff --git a/datafusion/ffi/tests/ffi_integration.rs b/datafusion/ffi/tests/ffi_integration.rs index bf1910c31279..1ef16fbaa4d8 100644 --- a/datafusion/ffi/tests/ffi_integration.rs +++ b/datafusion/ffi/tests/ffi_integration.rs @@ -20,13 +20,11 @@ #[cfg(feature = "integration-tests")] mod tests { use datafusion::error::{DataFusionError, Result}; - use datafusion::logical_expr::{AggregateUDF, ScalarUDF}; - use datafusion::prelude::{col, SessionContext}; + use datafusion::prelude::SessionContext; use datafusion_ffi::catalog_provider::ForeignCatalogProvider; use datafusion_ffi::table_provider::ForeignTableProvider; + use datafusion_ffi::tests::create_record_batch; use datafusion_ffi::tests::utils::get_module; - use datafusion_ffi::tests::{create_record_batch, ForeignLibraryModuleRef}; - use datafusion_ffi::udaf::ForeignAggregateUDF; use std::sync::Arc; /// It is important that this test is in the `tests` directory and not in the From bb84d08fbd1217bd98bb20a56cc82c5a7ab0107d Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Wed, 4 Jun 2025 15:27:30 -0400 Subject: [PATCH 31/32] Re-implement cleaned up code that was removed in last push --- datafusion/ffi/src/arrow_wrappers.rs | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/datafusion/ffi/src/arrow_wrappers.rs b/datafusion/ffi/src/arrow_wrappers.rs index 9d2d0cd25034..7b3751dcae82 100644 --- a/datafusion/ffi/src/arrow_wrappers.rs +++ b/datafusion/ffi/src/arrow_wrappers.rs @@ -45,16 +45,19 @@ impl From for 
WrappedSchema { WrappedSchema(ffi_schema) } } +/// Some functions are expected to always succeed, like getting the schema from a TableProvider. +/// Since going through the FFI always has the potential to fail, we need to catch these errors, +/// give the user a warning, and return some kind of result. In this case we default to an +/// empty schema. +#[cfg(not(tarpaulin_include))] +fn catch_df_schema_error(e: ArrowError) -> Schema { + error!("Unable to convert from FFI_ArrowSchema to DataFusion Schema in FFI_PlanProperties. {e}"); + Schema::empty() +} impl From for SchemaRef { fn from(value: WrappedSchema) -> Self { - let schema = match Schema::try_from(&value.0) { - Ok(s) => s, - Err(e) => { - error!("Unable to convert from FFI_ArrowSchema to DataFusion Schema in FFI_PlanProperties. {e}"); - Schema::empty() - } - }; + let schema = Schema::try_from(&value.0).unwrap_or_else(catch_df_schema_error); Arc::new(schema) } } From 1c3fad9bb10efda72d25e77780670b2593b66832 Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Wed, 4 Jun 2025 15:38:21 -0400 Subject: [PATCH 32/32] Minor review comments --- datafusion/ffi/src/udaf/accumulator.rs | 2 ++ datafusion/ffi/src/udaf/accumulator_args.rs | 6 ++++-- datafusion/ffi/src/udaf/groups_accumulator.rs | 11 ++++++----- 3 files changed, 12 insertions(+), 7 deletions(-) diff --git a/datafusion/ffi/src/udaf/accumulator.rs b/datafusion/ffi/src/udaf/accumulator.rs index b548ebc48243..80b872159f48 100644 --- a/datafusion/ffi/src/udaf/accumulator.rs +++ b/datafusion/ffi/src/udaf/accumulator.rs @@ -32,6 +32,8 @@ use prost::Message; use crate::{arrow_wrappers::WrappedArray, df_result, rresult, rresult_return}; /// A stable struct for sharing [`Accumulator`] across FFI boundaries. +/// For an explanation of each field, see the corresponding function +/// defined in [`Accumulator`]. 
#[repr(C)] #[derive(Debug, StableAbi)] #[allow(non_camel_case_types)] diff --git a/datafusion/ffi/src/udaf/accumulator_args.rs b/datafusion/ffi/src/udaf/accumulator_args.rs index e8a4afff4184..699af1d5c5e0 100644 --- a/datafusion/ffi/src/udaf/accumulator_args.rs +++ b/datafusion/ffi/src/udaf/accumulator_args.rs @@ -17,6 +17,7 @@ use std::sync::Arc; +use crate::arrow_wrappers::WrappedSchema; use abi_stable::{ std_types::{RString, RVec}, StableAbi, @@ -37,8 +38,9 @@ use datafusion_proto::{ }; use prost::Message; -use crate::arrow_wrappers::WrappedSchema; - +/// A stable struct for sharing [`AccumulatorArgs`] across FFI boundaries. +/// For an explanation of each field, see the corresponding field +/// defined in [`AccumulatorArgs`]. #[repr(C)] #[derive(Debug, StableAbi)] #[allow(non_camel_case_types)] diff --git a/datafusion/ffi/src/udaf/groups_accumulator.rs b/datafusion/ffi/src/udaf/groups_accumulator.rs index 09b16ff913b6..58a18c69db7c 100644 --- a/datafusion/ffi/src/udaf/groups_accumulator.rs +++ b/datafusion/ffi/src/udaf/groups_accumulator.rs @@ -17,6 +17,10 @@ use std::{ffi::c_void, ops::Deref, sync::Arc}; +use crate::{ + arrow_wrappers::{WrappedArray, WrappedSchema}, + df_result, rresult, rresult_return, +}; use abi_stable::{ std_types::{ROption, RResult, RString, RVec}, StableAbi, @@ -31,12 +35,9 @@ use datafusion::{ logical_expr::{EmitTo, GroupsAccumulator}, }; -use crate::{ - arrow_wrappers::{WrappedArray, WrappedSchema}, - df_result, rresult, rresult_return, -}; - /// A stable struct for sharing [`GroupsAccumulator`] across FFI boundaries. +/// For an explanation of each field, see the corresponding function +/// defined in [`GroupsAccumulator`]. #[repr(C)] #[derive(Debug, StableAbi)] #[allow(non_camel_case_types)]