diff --git a/crates/core-executor/src/duckdb/functions.rs b/crates/core-executor/src/duckdb/functions.rs index e105c4f99..26e745001 100644 --- a/crates/core-executor/src/duckdb/functions.rs +++ b/crates/core-executor/src/duckdb/functions.rs @@ -2,30 +2,76 @@ use crate::error::{self as ex_error, Result as CoreResult}; use arrow_schema::{DataType, Field, FieldRef}; use datafusion::arrow::array::Array; use datafusion::arrow::record_batch::RecordBatch; -use datafusion::logical_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl}; +use datafusion::logical_expr::{ColumnarValue, ScalarFunctionArgs, TypeSignature}; use datafusion_common::config::ConfigOptions; +use datafusion_common::internal_datafusion_err; +use datafusion_expr::{ScalarUDF, ScalarUDFImpl}; use duckdb::Connection; use duckdb::vscalar::{ArrowFunctionSignature, VArrowScalar}; -use embucket_functions::string_binary::length::LengthFunc; +use embucket_functions::conditional::{ + booland, boolor, boolxor, equal_null, iff, nullifzero, zeroifnull, +}; +use embucket_functions::crypto::md5; +use embucket_functions::datetime::date_part_extract::Interval; +use embucket_functions::datetime::{ + add_months, date_add, date_diff, date_from_parts, date_part_extract, dayname, last_day, + monthname, next_day, previous_day, time_from_parts, timestamp_from_parts, +}; +use embucket_functions::numeric::div0; +use embucket_functions::regexp::{ + regexp_instr, regexp_like, regexp_replace, regexp_substr, regexp_substr_all, +}; +use embucket_functions::string_binary::{ + hex_decode_binary, hex_decode_string, hex_encode, insert, jarowinkler_similarity as js, length, + lower, parse_ip, randstr, replace, rtrimmed_length, sha2, split, strtok, substr, +}; +use embucket_functions::system::{cancel_query, typeof_func}; use snafu::ResultExt; +use std::collections::HashMap; +use std::fmt::Debug; +use std::hash::BuildHasher; use std::{error::Error, sync::Arc}; -pub struct DfUdfWrapper { - _inner: T, +use strum::IntoEnumIterator; + +/// Generic adapter between DF and `DuckDB` +#[derive(Debug, Clone)] +pub struct DfUdfWrapper { + _marker: std::marker::PhantomData, +} + +/// Stores a specific `ScalarUDF` instance for `invoke()` +#[derive(Clone)] +pub struct UdfState { + udf: Arc, } -impl DfUdfWrapper { - pub const fn new(inner: T) -> Self { - Self { _inner: inner } +impl UdfState { + #[must_use] + pub const fn new(udf: Arc) -> Self { + Self { udf } + } + + #[must_use] + pub fn udf(&self) -> Arc { + self.udf.clone() + } +} + +impl Default for UdfState { + fn default() -> Self { + // dummy placeholder (will be overridden by register_with_state) + let fake = length::get_udf(); + Self::new(fake) } } impl VArrowScalar for DfUdfWrapper { - type State = (); + type State = UdfState; - fn invoke(_state: &Self::State, input: RecordBatch) -> Result, Box> { + fn invoke(state: &Self::State, input: RecordBatch) -> Result, Box> { let num_rows = input.num_rows(); let schema = input.schema(); - let func = T::default(); + let func = state.udf(); let args: Vec = input .columns() .iter() @@ -68,38 +114,177 @@ impl VArrowScalar for DfUdfWrapper { fn signatures() -> Vec { let func = T::default(); let sig = func.signature(); + expand_signature(&func, &sig.type_signature) + } +} - match &sig.type_signature { - datafusion::logical_expr::TypeSignature::Exact(types) => { - vec![ArrowFunctionSignature::exact( - types.clone(), - func.return_type(types).unwrap_or(DataType::Utf8), - )] - } - datafusion::logical_expr::TypeSignature::Variadic(valid_types) => { - vec![ArrowFunctionSignature::exact( - vec![valid_types.first().cloned().unwrap_or(DataType::Utf8)], - func.return_type(&[valid_types.first().cloned().unwrap_or(DataType::Utf8)]) - .unwrap_or(DataType::Utf8), - )] - } - datafusion::logical_expr::TypeSignature::Any(n) => { - let args = vec![DataType::Utf8; *n]; - let ret = func.return_type(&args).unwrap_or(DataType::Utf8); - vec![ArrowFunctionSignature::exact(args, ret)] - } - _ => { - let ret = func - .return_type(&[DataType::Utf8]) - .unwrap_or(DataType::Utf8); - vec![ArrowFunctionSignature::exact(vec![DataType::Utf8], ret)] - } - } +pub fn register_all_udfs( + conn: &Connection, + udfs: &HashMap, S>, +) -> CoreResult> +where + S: BuildHasher, +{ + let mut failed: Vec = Vec::new(); + + // String binary + register_duckdb_udf::(conn, udfs, &mut failed)?; + register_duckdb_udf_try::(conn, udfs, &mut failed)?; + register_duckdb_udf::(conn, udfs, &mut failed)?; + register_duckdb_udf_try::(conn, udfs, &mut failed)?; + register_duckdb_udf::(conn, udfs, &mut failed)?; + register_duckdb_udf::(conn, udfs, &mut failed)?; + register_duckdb_udf::(conn, udfs, &mut failed)?; + register_duckdb_udf::(conn, udfs, &mut failed)?; + + register_duckdb_udf::(conn, udfs, &mut failed)?; + register_duckdb_udf::(conn, udfs, &mut failed)?; + register_duckdb_udf::(conn, udfs, &mut failed)?; + register_duckdb_udf::(conn, udfs, &mut failed)?; + register_duckdb_udf::(conn, udfs, &mut failed)?; + register_duckdb_udf::(conn, udfs, &mut failed)?; + register_duckdb_udf::(conn, udfs, &mut failed)?; + register_duckdb_udf::(conn, udfs, &mut failed)?; + register_duckdb_udf::(conn, udfs, &mut failed)?; + + // Regexp + register_duckdb_udf::(conn, udfs, &mut failed)?; + register_duckdb_udf::(conn, udfs, &mut failed)?; + register_duckdb_udf::(conn, udfs, &mut failed)?; + register_duckdb_udf::(conn, udfs, &mut failed)?; + register_duckdb_udf::(conn, udfs, &mut failed)?; + + // Conditional + register_duckdb_udf::(conn, udfs, &mut failed)?; + register_duckdb_udf::(conn, udfs, &mut failed)?; + register_duckdb_udf::(conn, udfs, &mut failed)?; + register_duckdb_udf::(conn, udfs, &mut failed)?; + register_duckdb_udf::(conn, udfs, &mut failed)?; + register_duckdb_udf::(conn, udfs, &mut failed)?; + register_duckdb_udf::(conn, udfs, &mut failed)?; + + // Crypto + register_duckdb_udf::(conn, udfs, &mut failed)?; + + // Datetime + register_duckdb_udf::(conn, udfs, &mut failed)?; + register_duckdb_udf::(conn, udfs, &mut failed)?; + register_duckdb_udf::(conn, udfs, &mut failed)?; + register_duckdb_udf::(conn, udfs, &mut failed)?; + register_duckdb_udf::(conn, udfs, &mut failed)?; + register_duckdb_udf::(conn, udfs, &mut failed)?; + register_duckdb_udf::(conn, udfs, &mut failed)?; + register_duckdb_udf::(conn, udfs, &mut failed)?; + register_duckdb_udf::(conn, udfs, &mut failed)?; + register_duckdb_udf::(conn, udfs, &mut failed)?; + register_duckdb_udf::( + conn, + udfs, + &mut failed, + )?; + for interval in Interval::iter() { + register_duckdb_udf_internal::( + conn, + udfs, + &interval.to_string(), + &mut failed, + )?; } + + // Numeric + register_duckdb_udf_internal::(conn, udfs, "div0null", &mut failed)?; + register_duckdb_udf::(conn, udfs, &mut failed)?; + + // System + register_duckdb_udf::(conn, udfs, &mut failed)?; + register_duckdb_udf::(conn, udfs, &mut failed)?; + Ok(failed) +} + +/// Registers a normal (non-try_) UDF in `DuckDB`. +pub fn register_duckdb_udf( + conn: &Connection, + udfs: &HashMap, S>, + failed: &mut Vec, +) -> CoreResult<()> +where + T: ScalarUDFImpl + Default + 'static, + S: BuildHasher, +{ + let name = T::default().name().to_string(); + register_duckdb_udf_internal::(conn, udfs, &name, failed) +} + +/// Registers a “try_” variant of a UDF in `DuckDB`. +pub fn register_duckdb_udf_try( + conn: &Connection, + udfs: &HashMap, S>, + failed: &mut Vec, +) -> CoreResult<()> +where + T: ScalarUDFImpl + Default + 'static, + S: BuildHasher, +{ + let name = format!("try_{}", T::default().name()); + register_duckdb_udf_internal::(conn, udfs, &name, failed) } -pub fn register_all_udfs(conn: &Connection) -> CoreResult<()> { - conn.register_scalar_function::>("length_test") - .context(ex_error::DuckdbSnafu)?; +/// Shared internal logic for both normal and try_ function registration. +fn register_duckdb_udf_internal( + conn: &Connection, + udfs: &HashMap, S>, + name: &str, + failed: &mut Vec, +) -> CoreResult<()> +where + T: ScalarUDFImpl + Default + 'static, + S: BuildHasher, +{ + let func = udfs + .get(name) + .ok_or_else(|| internal_datafusion_err!("Unable to find expected '{name}' function")) + .context(ex_error::DataFusionSnafu)?; + + let state = UdfState::new(func.clone()); + if conn + .register_scalar_function_with_state::>(name, &state) + .is_err() + { + failed.push(name.to_string()); + } Ok(()) } + +fn expand_signature( + func: &T, + sig: &TypeSignature, +) -> Vec { + // DataFusion already knows all valid argument type combinations for this signature + let example_sigs = match sig { + TypeSignature::Any(arg_count) => { + let types = vec![DataType::Utf8, DataType::LargeUtf8, DataType::Utf8View]; + types + .into_iter() + .map(|dt| vec![dt; *arg_count]) + .collect::>() + } + _ => sig.get_example_types(), + }; + + if example_sigs.is_empty() { + // Fallback when no examples are available (e.g., for generic or nullary signatures) + let ret = func + .return_type(&[DataType::Utf8]) + .unwrap_or(DataType::Utf8); + return vec![ArrowFunctionSignature::exact(vec![DataType::Utf8], ret)]; + } + + // Build a DuckDB signature for each valid argument combination + example_sigs + .into_iter() + .map(|types| { + let ret = func.return_type(&types).unwrap_or(DataType::Utf8); + ArrowFunctionSignature::exact(types, ret) + }) + .collect() +} diff --git a/crates/core-executor/src/query.rs b/crates/core-executor/src/query.rs index ad23bf63b..4d15b1946 100644 --- a/crates/core-executor/src/query.rs +++ b/crates/core-executor/src/query.rs @@ -512,7 +512,14 @@ impl UserQuery { let sql = self.query.clone(); let conn = Connection::open_in_memory().context(ex_error::DuckdbSnafu)?; - register_all_udfs(&conn)?; + let failed = register_all_udfs(&conn, self.session.ctx.state().scalar_functions())?; + if !failed.is_empty() { + tracing::warn!( + "Some UDFs were not registered/overloaded in DuckDB: {:?}", + failed + ); + } + apply_connection_setup_queries(&conn, &setup_queries)?; if self.session.config.use_duck_db_explain diff --git a/crates/embucket-functions/src/datetime/date_part_extract.rs b/crates/embucket-functions/src/datetime/date_part_extract.rs index ed6d1c8f2..04278964a 100644 --- a/crates/embucket-functions/src/datetime/date_part_extract.rs +++ b/crates/embucket-functions/src/datetime/date_part_extract.rs @@ -13,6 +13,7 @@ use datafusion_expr::registry::FunctionRegistry; use datafusion_expr::{ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Signature, Volatility}; use snafu::OptionExt; use std::any::Any; +use std::fmt; use std::sync::Arc; use strum::IntoEnumIterator; use strum_macros::EnumIter; @@ -37,6 +38,30 @@ pub enum Interval { Second, } +impl fmt::Display for Interval { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let s = match self { + Self::Year => "year", + Self::YearOfWeek => "yearofweek", + Self::YearOfWeekIso => "yearofweekiso", + Self::Day => "day", + Self::DayOfMonth => "dayofmonth", + Self::DayOfWeek => "dayofweek", + Self::DayOfWeekIso => "dayofweekiso", + Self::DayOfYear => "dayofyear", + Self::Week => "week", + Self::WeekOfYear => "weekofyear", + Self::WeekIso => "weekiso", + Self::Month => "month", + Self::Quarter => "quarter", + Self::Hour => "hour", + Self::Minute => "minute", + Self::Second => "second", + }; + write!(f, "{s}") + } +} + /// `YEAR*` / `DAY*` / `WEEK*` / `MONTH` / `QUARTER` / `HOUR` / `MINUTE` / `SECOND` SQL function /// /// Extracts a specific part of a date or timestamp. diff --git a/crates/embucket-functions/src/regexp/mod.rs b/crates/embucket-functions/src/regexp/mod.rs index f3a3ca384..03396d3e3 100644 --- a/crates/embucket-functions/src/regexp/mod.rs +++ b/crates/embucket-functions/src/regexp/mod.rs @@ -1,9 +1,9 @@ pub mod errors; pub mod regexp_instr; -mod regexp_like; -mod regexp_replace; -mod regexp_substr; -mod regexp_substr_all; +pub mod regexp_like; +pub mod regexp_replace; +pub mod regexp_substr; +pub mod regexp_substr_all; use crate::regexp::regexp_instr::RegexpInstrFunc; use crate::regexp::regexp_like::RegexpLikeFunc; diff --git a/crates/embucket-functions/src/regexp/regexp_like.rs b/crates/embucket-functions/src/regexp/regexp_like.rs index 63221191f..88f2407f8 100644 --- a/crates/embucket-functions/src/regexp/regexp_like.rs +++ b/crates/embucket-functions/src/regexp/regexp_like.rs @@ -55,6 +55,7 @@ impl Default for RegexpLikeFunc { } impl RegexpLikeFunc { + #[must_use] pub fn new() -> Self { Self { signature: Signature::one_of( diff --git a/crates/embucket-functions/src/regexp/regexp_replace.rs b/crates/embucket-functions/src/regexp/regexp_replace.rs index 5a54b7870..a9f807d6c 100644 --- a/crates/embucket-functions/src/regexp/regexp_replace.rs +++ b/crates/embucket-functions/src/regexp/regexp_replace.rs @@ -59,6 +59,7 @@ impl Default for RegexpReplaceFunc { } impl RegexpReplaceFunc { + #[must_use] pub fn new() -> Self { Self { signature: Signature::one_of( diff --git a/crates/embucket-functions/src/regexp/regexp_substr.rs b/crates/embucket-functions/src/regexp/regexp_substr.rs index 2c403960c..370b9d8d8 100644 --- a/crates/embucket-functions/src/regexp/regexp_substr.rs +++ b/crates/embucket-functions/src/regexp/regexp_substr.rs @@ -65,6 +65,7 @@ impl Default for RegexpSubstrFunc { } impl RegexpSubstrFunc { + #[must_use] pub fn new() -> Self { Self { signature: Signature::one_of( diff --git a/crates/embucket-functions/src/regexp/regexp_substr_all.rs b/crates/embucket-functions/src/regexp/regexp_substr_all.rs index 8681fbd52..3695ce6bc 100644 --- a/crates/embucket-functions/src/regexp/regexp_substr_all.rs +++ b/crates/embucket-functions/src/regexp/regexp_substr_all.rs @@ -67,6 +67,7 @@ impl Default for RegexpSubstrAllFunc { } impl RegexpSubstrAllFunc { + #[must_use] pub fn new() -> Self { Self { signature: Signature::one_of( diff --git a/crates/embucket-functions/src/string-binary/parse_ip.rs b/crates/embucket-functions/src/string-binary/parse_ip.rs index cb76279c2..ee512648c 100644 --- a/crates/embucket-functions/src/string-binary/parse_ip.rs +++ b/crates/embucket-functions/src/string-binary/parse_ip.rs @@ -7,9 +7,7 @@ use datafusion::logical_expr::{ }; use datafusion_common::ScalarValue; use datafusion_common::cast::as_generic_string_array; -use datafusion_common::types::{ - NativeType, logical_float16, logical_float32, logical_float64, logical_string, -}; +use datafusion_common::types::{NativeType, logical_float32, logical_float64, logical_string}; use datafusion_expr::{ScalarFunctionArgs, ScalarUDFImpl}; use ipnet::IpNet; use serde_json::json; @@ -65,7 +63,6 @@ impl ParseIpFunc { TypeSignatureClass::Native(logical_float64()), vec![ TypeSignatureClass::Integer, - TypeSignatureClass::Native(logical_float16()), TypeSignatureClass::Native(logical_float32()), ], NativeType::Float64,