diff --git a/Cargo.lock b/Cargo.lock index cc771331ebb3..a503c63b998d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2163,6 +2163,7 @@ dependencies = [ "bytes", "dashmap", "datafusion", + "datafusion-ffi", "datafusion-proto", "env_logger", "futures", @@ -2237,6 +2238,7 @@ version = "46.0.1" dependencies = [ "abi_stable", "arrow", + "arrow-schema", "async-ffi", "async-trait", "datafusion", diff --git a/datafusion-examples/Cargo.toml b/datafusion-examples/Cargo.toml index f6b7d641d126..2ba1673d97b9 100644 --- a/datafusion-examples/Cargo.toml +++ b/datafusion-examples/Cargo.toml @@ -62,6 +62,7 @@ bytes = { workspace = true } dashmap = { workspace = true } # note only use main datafusion crate for examples datafusion = { workspace = true, default-features = true } +datafusion-ffi = { workspace = true } datafusion-proto = { workspace = true } env_logger = { workspace = true } futures = { workspace = true } diff --git a/datafusion/ffi/Cargo.toml b/datafusion/ffi/Cargo.toml index 5c80c1b04225..29f40df51444 100644 --- a/datafusion/ffi/Cargo.toml +++ b/datafusion/ffi/Cargo.toml @@ -40,6 +40,7 @@ crate-type = ["cdylib", "rlib"] [dependencies] abi_stable = "0.11.3" arrow = { workspace = true, features = ["ffi"] } +arrow-schema = { workspace = true } async-ffi = { version = "0.5.0", features = ["abi_stable"] } async-trait = { workspace = true } datafusion = { workspace = true, default-features = false } diff --git a/datafusion/ffi/src/lib.rs b/datafusion/ffi/src/lib.rs index 877129fc5bb1..d877e182a1d8 100644 --- a/datafusion/ffi/src/lib.rs +++ b/datafusion/ffi/src/lib.rs @@ -35,6 +35,7 @@ pub mod session_config; pub mod table_provider; pub mod table_source; pub mod udf; +pub mod udtf; pub mod util; pub mod volatility; diff --git a/datafusion/ffi/src/tests/mod.rs b/datafusion/ffi/src/tests/mod.rs index c7a9816431e1..7a36ee52bdb4 100644 --- a/datafusion/ffi/src/tests/mod.rs +++ b/datafusion/ffi/src/tests/mod.rs @@ -27,7 +27,7 @@ use abi_stable::{ }; use catalog::create_catalog_provider; -use crate::catalog_provider::FFI_CatalogProvider; +use crate::{catalog_provider::FFI_CatalogProvider, udtf::FFI_TableFunction}; use super::{table_provider::FFI_TableProvider, udf::FFI_ScalarUDF}; use arrow::array::RecordBatch; @@ -37,7 +37,7 @@ use datafusion::{ common::record_batch, }; use sync_provider::create_sync_table_provider; -use udf_udaf_udwf::{create_ffi_abs_func, create_ffi_random_func}; +use udf_udaf_udwf::{create_ffi_abs_func, create_ffi_random_func, create_ffi_table_func}; mod async_provider; pub mod catalog; @@ -63,6 +63,8 @@ pub struct ForeignLibraryModule { pub create_nullary_udf: extern "C" fn() -> FFI_ScalarUDF, + pub create_table_function: extern "C" fn() -> FFI_TableFunction, + pub version: extern "C" fn() -> u64, } @@ -109,6 +111,7 @@ pub fn get_foreign_library_module() -> ForeignLibraryModuleRef { create_table: construct_table_provider, create_scalar_udf: create_ffi_abs_func, create_nullary_udf: create_ffi_random_func, + create_table_function: create_ffi_table_func, version: super::version, } .leak_into_prefix() diff --git a/datafusion/ffi/src/tests/udf_udaf_udwf.rs b/datafusion/ffi/src/tests/udf_udaf_udwf.rs index b40bec762bd7..c3cb1bcc3533 100644 --- a/datafusion/ffi/src/tests/udf_udaf_udwf.rs +++ b/datafusion/ffi/src/tests/udf_udaf_udwf.rs @@ -15,9 +15,11 @@ // specific language governing permissions and limitations // under the License. -use crate::udf::FFI_ScalarUDF; +use crate::{udf::FFI_ScalarUDF, udtf::FFI_TableFunction}; use datafusion::{ + catalog::TableFunctionImpl, functions::math::{abs::AbsFunc, random::RandomFunc}, + functions_table::generate_series::RangeFunc, logical_expr::ScalarUDF, }; @@ -34,3 +36,9 @@ pub(crate) extern "C" fn create_ffi_random_func() -> FFI_ScalarUDF { udf.into() } + +pub(crate) extern "C" fn create_ffi_table_func() -> FFI_TableFunction { + let udtf: Arc = Arc::new(RangeFunc {}); + + FFI_TableFunction::new(udtf, None) +} diff --git a/datafusion/ffi/src/udtf.rs b/datafusion/ffi/src/udtf.rs new file mode 100644 index 000000000000..1e06247546be --- /dev/null +++ b/datafusion/ffi/src/udtf.rs @@ -0,0 +1,321 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::{ffi::c_void, sync::Arc}; + +use abi_stable::{ + std_types::{RResult, RString, RVec}, + StableAbi, +}; + +use datafusion::error::Result; +use datafusion::{ + catalog::{TableFunctionImpl, TableProvider}, + prelude::{Expr, SessionContext}, +}; +use datafusion_proto::{ + logical_plan::{ + from_proto::parse_exprs, to_proto::serialize_exprs, DefaultLogicalExtensionCodec, + }, + protobuf::LogicalExprList, +}; +use prost::Message; +use tokio::runtime::Handle; + +use crate::{ + df_result, rresult_return, + table_provider::{FFI_TableProvider, ForeignTableProvider}, +}; + +/// A stable struct for sharing a [`TableFunctionImpl`] across FFI boundaries. +#[repr(C)] +#[derive(Debug, StableAbi)] +#[allow(non_camel_case_types)] +pub struct FFI_TableFunction { + /// Equivalent to the `call` function of the TableFunctionImpl. + /// The arguments are Expr passed as protobuf encoded bytes. + pub call: unsafe extern "C" fn( + udtf: &Self, + args: RVec, + ) -> RResult, + + /// Used to create a clone on the provider of the udtf. This should + /// only need to be called by the receiver of the udtf. + pub clone: unsafe extern "C" fn(udtf: &Self) -> Self, + + /// Release the memory of the private data when it is no longer being used. + pub release: unsafe extern "C" fn(udtf: &mut Self), + + /// Internal data. This is only to be accessed by the provider of the udtf. + /// A [`ForeignTableFunction`] should never attempt to access this data. + pub private_data: *mut c_void, +} + +unsafe impl Send for FFI_TableFunction {} +unsafe impl Sync for FFI_TableFunction {} + +pub struct TableFunctionPrivateData { + udtf: Arc, + runtime: Option, +} + +impl FFI_TableFunction { + fn inner(&self) -> &Arc { + let private_data = self.private_data as *const TableFunctionPrivateData; + unsafe { &(*private_data).udtf } + } + + fn runtime(&self) -> Option { + let private_data = self.private_data as *const TableFunctionPrivateData; + unsafe { (*private_data).runtime.clone() } + } +} + +unsafe extern "C" fn call_fn_wrapper( + udtf: &FFI_TableFunction, + args: RVec, +) -> RResult { + let runtime = udtf.runtime(); + let udtf = udtf.inner(); + + let default_ctx = SessionContext::new(); + let codec = DefaultLogicalExtensionCodec {}; + + let proto_filters = rresult_return!(LogicalExprList::decode(args.as_ref())); + + let args = + rresult_return!(parse_exprs(proto_filters.expr.iter(), &default_ctx, &codec)); + + let table_provider = rresult_return!(udtf.call(&args)); + RResult::ROk(FFI_TableProvider::new(table_provider, false, runtime)) +} + +unsafe extern "C" fn release_fn_wrapper(udtf: &mut FFI_TableFunction) { + let private_data = Box::from_raw(udtf.private_data as *mut TableFunctionPrivateData); + drop(private_data); +} + +unsafe extern "C" fn clone_fn_wrapper(udtf: &FFI_TableFunction) -> FFI_TableFunction { + let runtime = udtf.runtime(); + let udtf = udtf.inner(); + + FFI_TableFunction::new(Arc::clone(udtf), runtime) +} + +impl Clone for FFI_TableFunction { + fn clone(&self) -> Self { + unsafe { (self.clone)(self) } + } +} + +impl FFI_TableFunction { + pub fn new(udtf: Arc, runtime: Option) -> Self { + let private_data = Box::new(TableFunctionPrivateData { udtf, runtime }); + + Self { + call: call_fn_wrapper, + clone: clone_fn_wrapper, + release: release_fn_wrapper, + private_data: Box::into_raw(private_data) as *mut c_void, + } + } +} + +impl From> for FFI_TableFunction { + fn from(udtf: Arc) -> Self { + let private_data = Box::new(TableFunctionPrivateData { + udtf, + runtime: None, + }); + + Self { + call: call_fn_wrapper, + clone: clone_fn_wrapper, + release: release_fn_wrapper, + private_data: Box::into_raw(private_data) as *mut c_void, + } + } +} + +impl Drop for FFI_TableFunction { + fn drop(&mut self) { + unsafe { (self.release)(self) } + } +} + +/// This struct is used to access an UDTF provided by a foreign +/// library across a FFI boundary. +/// +/// The ForeignTableFunction is to be used by the caller of the UDTF, so it has +/// no knowledge or access to the private data. All interaction with the UDTF +/// must occur through the functions defined in FFI_TableFunction. +#[derive(Debug)] +pub struct ForeignTableFunction(FFI_TableFunction); + +unsafe impl Send for ForeignTableFunction {} +unsafe impl Sync for ForeignTableFunction {} + +impl From for ForeignTableFunction { + fn from(value: FFI_TableFunction) -> Self { + Self(value) + } +} + +impl TableFunctionImpl for ForeignTableFunction { + fn call(&self, args: &[Expr]) -> Result> { + let codec = DefaultLogicalExtensionCodec {}; + let expr_list = LogicalExprList { + expr: serialize_exprs(args, &codec)?, + }; + let filters_serialized = expr_list.encode_to_vec().into(); + + let table_provider = unsafe { (self.0.call)(&self.0, filters_serialized) }; + + let table_provider = df_result!(table_provider)?; + let table_provider: ForeignTableProvider = (&table_provider).into(); + + Ok(Arc::new(table_provider)) + } +} + +#[cfg(test)] +mod tests { + use arrow::{ + array::{ + record_batch, ArrayRef, Float64Array, RecordBatch, StringArray, UInt64Array, + }, + datatypes::{DataType, Field, Schema}, + }; + use datafusion::{ + catalog::MemTable, common::exec_err, prelude::lit, scalar::ScalarValue, + }; + + use super::*; + + #[derive(Debug)] + struct TestUDTF {} + + impl TableFunctionImpl for TestUDTF { + fn call(&self, args: &[Expr]) -> Result> { + let args = args + .iter() + .map(|arg| { + if let Expr::Literal(scalar) = arg { + Ok(scalar) + } else { + exec_err!("Expected only literal arguments to table udf") + } + }) + .collect::>>()?; + + if args.len() < 2 { + exec_err!("Expected at least two arguments to table udf")? + } + + let ScalarValue::UInt64(Some(num_rows)) = args[0].to_owned() else { + exec_err!( + "First argument must be the number of elements to create as u64" + )? + }; + let num_rows = num_rows as usize; + + let mut fields = Vec::default(); + let mut arrays1 = Vec::default(); + let mut arrays2 = Vec::default(); + + let split = num_rows / 3; + for (idx, arg) in args[1..].iter().enumerate() { + let (field, array) = match arg { + ScalarValue::Utf8(s) => { + let s_vec = vec![s.to_owned(); num_rows]; + ( + Field::new(format!("field-{}", idx), DataType::Utf8, true), + Arc::new(StringArray::from(s_vec)) as ArrayRef, + ) + } + ScalarValue::UInt64(v) => { + let v_vec = vec![v.to_owned(); num_rows]; + ( + Field::new(format!("field-{}", idx), DataType::UInt64, true), + Arc::new(UInt64Array::from(v_vec)) as ArrayRef, + ) + } + ScalarValue::Float64(v) => { + let v_vec = vec![v.to_owned(); num_rows]; + ( + Field::new(format!("field-{}", idx), DataType::Float64, true), + Arc::new(Float64Array::from(v_vec)) as ArrayRef, + ) + } + _ => exec_err!( + "Test case only supports utf8, u64, and f64. Found {}", + arg.data_type() + )?, + }; + + fields.push(field); + arrays1.push(array.slice(0, split)); + arrays2.push(array.slice(split, num_rows - split)); + } + + let schema = Arc::new(Schema::new(fields)); + let batches = vec![ + RecordBatch::try_new(Arc::clone(&schema), arrays1)?, + RecordBatch::try_new(Arc::clone(&schema), arrays2)?, + ]; + + let table_provider = MemTable::try_new(schema, vec![batches])?; + + Ok(Arc::new(table_provider)) + } + } + + #[tokio::test] + async fn test_round_trip_udtf() -> Result<()> { + let original_udtf = Arc::new(TestUDTF {}) as Arc; + + let local_udtf: FFI_TableFunction = + FFI_TableFunction::new(Arc::clone(&original_udtf), None); + + let foreign_udf: ForeignTableFunction = local_udtf.into(); + + let table = + foreign_udf.call(&vec![lit(6_u64), lit("one"), lit(2.0), lit(3_u64)])?; + + let ctx = SessionContext::default(); + let _ = ctx.register_table("test-table", table)?; + + let returned_batches = ctx.table("test-table").await?.collect().await?; + + assert_eq!(returned_batches.len(), 2); + let expected_batch_0 = record_batch!( + ("field-0", Utf8, ["one", "one"]), + ("field-1", Float64, [2.0, 2.0]), + ("field-2", UInt64, [3, 3]) + )?; + assert_eq!(returned_batches[0], expected_batch_0); + + let expected_batch_1 = record_batch!( + ("field-0", Utf8, ["one", "one", "one", "one"]), + ("field-1", Float64, [2.0, 2.0, 2.0, 2.0]), + ("field-2", UInt64, [3, 3, 3, 3]) + )?; + assert_eq!(returned_batches[1], expected_batch_1); + + Ok(()) + } +} diff --git a/datafusion/ffi/tests/ffi_udtf.rs b/datafusion/ffi/tests/ffi_udtf.rs new file mode 100644 index 000000000000..5a46211d3b9c --- /dev/null +++ b/datafusion/ffi/tests/ffi_udtf.rs @@ -0,0 +1,64 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +/// Add an additional module here for convenience to scope this to only +/// when the feature integtation-tests is built +#[cfg(feature = "integration-tests")] +mod tests { + + use std::sync::Arc; + + use arrow::array::{create_array, ArrayRef}; + use datafusion::error::{DataFusionError, Result}; + use datafusion::prelude::SessionContext; + + use datafusion_ffi::tests::utils::get_module; + use datafusion_ffi::udtf::ForeignTableFunction; + + /// This test validates that we can load an external module and use a scalar + /// udf defined in it via the foreign function interface. In this case we are + /// using the abs() function as our scalar UDF. + #[tokio::test] + async fn test_user_defined_table_function() -> Result<()> { + let module = get_module()?; + + let ffi_table_func = module + .create_table_function() + .ok_or(DataFusionError::NotImplemented( + "External table function provider failed to implement create_table_function" + .to_string(), + ))?(); + let foreign_table_func: ForeignTableFunction = ffi_table_func.into(); + + let udtf = Arc::new(foreign_table_func); + + let ctx = SessionContext::default(); + ctx.register_udtf("my_range", udtf); + + let result = ctx + .sql("SELECT * FROM my_range(5)") + .await? + .collect() + .await?; + let expected = create_array!(Int64, [0, 1, 2, 3, 4]) as ArrayRef; + + assert!(result.len() == 1); + assert!(result[0].column(0) == &expected); + + Ok(()) + } +}