From 86e8c38bd4c5a35a938e4b9efa3244540d67c5f1 Mon Sep 17 00:00:00 2001
From: osipovartem
Date: Fri, 24 Oct 2025 17:35:04 +0300
Subject: [PATCH] Register session functions and overload other UDFs

---
 crates/core-executor/Cargo.toml               |   2 +-
 crates/core-executor/src/duckdb/functions.rs  | 271 +++++++++++++-----
 crates/core-executor/src/duckdb/query.rs      |  67 +----
 crates/core-executor/src/query.rs             |  55 ++--
 crates/core-executor/src/session.rs           |  29 ++
 crates/embucket-functions/src/datetime/mod.rs |   6 +-
 crates/embucket-functions/src/lib.rs          |   2 +-
 .../src/session/current_database.rs           |  65 +++++
 .../src/session/current_schema.rs             |  65 +++++
 .../src/session/current_version.rs            |  65 +++++
 crates/embucket-functions/src/session/mod.rs  |  49 ++--
 crates/embucket-functions/src/tests/utils.rs  |   5 +-
 12 files changed, 499 insertions(+), 182 deletions(-)
 create mode 100644 crates/embucket-functions/src/session/current_database.rs
 create mode 100644 crates/embucket-functions/src/session/current_schema.rs
 create mode 100644 crates/embucket-functions/src/session/current_version.rs

diff --git a/crates/core-executor/Cargo.toml b/crates/core-executor/Cargo.toml
index b2cc36c80..d3d530cf1 100644
--- a/crates/core-executor/Cargo.toml
+++ b/crates/core-executor/Cargo.toml
@@ -19,7 +19,7 @@ chrono = { workspace = true }
 dashmap = { workspace = true }
 async-stream = { version = "0.3.6"}
 #duckdb = { version = "=1.3.1", package = "spiceai_duckdb_fork" } # Forked to add support for duckdb_scan_arrow, pending: https://github.com/duckdb/duckdb-rs/pull/488
-duckdb = { package = "duckdb", version = "1.4.1", features = ["vscalar", "vscalar-arrow", "bundled"] }
+duckdb = { package = "duckdb", version = "1.4.1", features = ["vscalar", "vscalar-arrow", "bundled", "vtab", "vtab-arrow"] }
 #datafusion-table-providers = { version = "0.8.1", features = ["duckdb"] }
 datafusion = { workspace = true }
 datafusion-common = { workspace = true }
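Note on the feature additions: the new SafeVScalar adapter in functions.rs below leans on DuckDB's Arrow conversion helpers, which duckdb-rs gates behind the `vtab`/`vtab-arrow` features. A minimal sketch of the round-trip those features unlock, using only the `duckdb` crate items this patch itself imports (the helper name `passthrough` is illustrative):

use duckdb::core::DataChunkHandle;
use duckdb::vtab::arrow::{WritableVector, data_chunk_to_arrow, write_arrow_array_to_vector};

// Convert a DuckDB chunk into an Arrow RecordBatch, then write one column back out.
fn passthrough(
    input: &mut DataChunkHandle,
    out: &mut dyn WritableVector,
) -> Result<(), Box<dyn std::error::Error>> {
    let batch = data_chunk_to_arrow(input)?; // DuckDB -> Arrow
    write_arrow_array_to_vector(batch.column(0), out) // Arrow -> DuckDB
}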
diff --git a/crates/core-executor/src/duckdb/functions.rs b/crates/core-executor/src/duckdb/functions.rs
index 26e745001..fe07fbec5 100644
--- a/crates/core-executor/src/duckdb/functions.rs
+++ b/crates/core-executor/src/duckdb/functions.rs
@@ -1,13 +1,19 @@
 use crate::error::{self as ex_error, Result as CoreResult};
-use arrow_schema::{DataType, Field, FieldRef};
+use arrow_schema::{DataType, Field, FieldRef, Schema};
 use datafusion::arrow::array::Array;
 use datafusion::arrow::record_batch::RecordBatch;
 use datafusion::logical_expr::{ColumnarValue, ScalarFunctionArgs, TypeSignature};
+use datafusion::sql::parser::Statement as DFStatement;
 use datafusion_common::config::ConfigOptions;
 use datafusion_common::internal_datafusion_err;
 use datafusion_expr::{ScalarUDF, ScalarUDFImpl};
 use duckdb::Connection;
-use duckdb::vscalar::{ArrowFunctionSignature, VArrowScalar};
+use duckdb::core::DataChunkHandle;
+use duckdb::vscalar::{
+    ArrowFunctionSignature, ArrowScalarParams, ScalarFunctionSignature, VArrowScalar, VScalar,
+};
+use duckdb::vtab::arrow::{WritableVector, data_chunk_to_arrow, write_arrow_array_to_vector};
+use duckdb::vtab::to_duckdb_logical_type;
 use embucket_functions::conditional::{
     booland, boolor, boolxor, equal_null, iff, nullifzero, zeroifnull,
 };
@@ -21,22 +27,77 @@ use embucket_functions::numeric::div0;
 use embucket_functions::regexp::{
     regexp_instr, regexp_like, regexp_replace, regexp_substr, regexp_substr_all,
 };
+use embucket_functions::session::{current_database, current_schema, current_version};
 use embucket_functions::string_binary::{
     hex_decode_binary, hex_decode_string, hex_encode, insert, jarowinkler_similarity as js,
     length, lower, parse_ip, randstr, replace, rtrimmed_length, sha2, split, strtok, substr,
 };
 use embucket_functions::system::{cancel_query, typeof_func};
 use snafu::ResultExt;
-use std::collections::HashMap;
+use sqlparser::ast::{Expr, Ident, ObjectName, Statement, VisitMut, VisitorMut};
+use std::collections::{HashMap, HashSet};
 use std::fmt::Debug;
 use std::hash::BuildHasher;
+use std::ops::ControlFlow;
 use std::{error::Error, sync::Arc};
 use strum::IntoEnumIterator;
 
 /// Generic adapter between DF and `DuckDB`
 #[derive(Debug, Clone)]
-pub struct DfUdfWrapper<T> {
-    _marker: std::marker::PhantomData<T>,
-}
+pub struct DfUdfWrapper<T>(std::marker::PhantomData<T>);
+pub struct SafeVScalar<T>(std::marker::PhantomData<T>);
+
+impl<T: VArrowScalar> VScalar for SafeVScalar<T> {
+    type State = T::State;
+
+    unsafe fn invoke(
+        info: &Self::State,
+        input: &mut DataChunkHandle,
+        out: &mut dyn WritableVector,
+    ) -> Result<(), Box<dyn Error>> {
+        let rb = if input.num_columns() == 0 {
+            let row_count = input.len().max(1);
+            let schema = Arc::new(Schema::empty());
+            RecordBatch::try_new_with_options(
+                schema,
+                vec![],
+                &duckdb::arrow::record_batch::RecordBatchOptions::new()
+                    .with_row_count(Some(row_count)),
+            )?
+        } else {
+            data_chunk_to_arrow(input)?
+        };
+        let array = T::invoke(info, rb)?;
+        write_arrow_array_to_vector(&array, out)
+    }
+
+    #[allow(clippy::expect_used)]
+    fn signatures() -> Vec<ScalarFunctionSignature> {
+        T::signatures()
+            .into_iter()
+            .map(|sig: ArrowFunctionSignature| {
+                let params = match sig.parameters {
+                    Some(ArrowScalarParams::Exact(param_types)) => param_types
+                        .into_iter()
+                        .map(|dt| {
+                            to_duckdb_logical_type(&dt).expect("failed to convert parameter type")
+                        })
+                        .collect(),
+                    Some(ArrowScalarParams::Variadic(param_types)) => {
+                        let converted = to_duckdb_logical_type(&param_types)
+                            .expect("failed to convert variadic type");
+                        vec![converted]
+                    }
+                    _ => vec![],
+                };
+
+                let ret_type = to_duckdb_logical_type(&sig.return_type)
+                    .expect("failed to convert return type");
+
+                ScalarFunctionSignature::exact(params, ret_type)
+            })
+            .collect()
+    }
+}
 
 /// Stores a specific `ScalarUDF` instance for `invoke()`
@@ -100,9 +161,7 @@ impl VArrowScalar for DfUdfWrapper {
             config_options: Arc::new(ConfigOptions::default()),
         };
 
-        let result = func.invoke_with_args(args_struct)?;
-
-        match result {
+        match func.invoke_with_args(args_struct)? {
             ColumnarValue::Array(arr) => Ok(arr),
             ColumnarValue::Scalar(scalar) => {
                 let array = scalar.to_array_of_size(num_rows)?;
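Why the zero-column branch in SafeVScalar::invoke is needed (a hedged aside): for a nullary call DuckDB hands over a DataChunk with no columns, and an Arrow RecordBatch with an empty schema cannot infer its row count, so the count must be supplied explicitly. A minimal self-contained reproduction of that workaround, assuming the arrow re-exports used elsewhere in this diff:

use std::sync::Arc;

use arrow_schema::Schema;
use datafusion::arrow::record_batch::{RecordBatch, RecordBatchOptions};

fn empty_batch_with_rows(row_count: usize) -> Result<RecordBatch, Box<dyn std::error::Error>> {
    // With no columns there is no implicit row count; pass it through the options instead.
    let batch = RecordBatch::try_new_with_options(
        Arc::new(Schema::empty()),
        vec![],
        &RecordBatchOptions::new().with_row_count(Some(row_count)),
    )?;
    Ok(batch)
}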
@@ -120,125 +179,144 @@ pub fn register_all_udfs(
     conn: &Connection,
+    statement: &mut DFStatement,
     udfs: &HashMap<String, Arc<ScalarUDF>, S>,
-) -> CoreResult<Vec<String>>
+) -> CoreResult<()>
 where
-    S: BuildHasher,
+    S: BuildHasher + Clone,
 {
-    let mut failed: Vec<String> = Vec::new();
+    let mut succeeded: HashSet<String, S> = HashSet::with_hasher(udfs.hasher().clone());
 
     // String binary
-    register_duckdb_udf::(conn, udfs, &mut failed)?;
-    register_duckdb_udf_try::(conn, udfs, &mut failed)?;
-    register_duckdb_udf::(conn, udfs, &mut failed)?;
-    register_duckdb_udf_try::(conn, udfs, &mut failed)?;
-    register_duckdb_udf::(conn, udfs, &mut failed)?;
-    register_duckdb_udf::(conn, udfs, &mut failed)?;
-    register_duckdb_udf::(conn, udfs, &mut failed)?;
-    register_duckdb_udf::(conn, udfs, &mut failed)?;
-
-    register_duckdb_udf::(conn, udfs, &mut failed)?;
-    register_duckdb_udf::(conn, udfs, &mut failed)?;
-    register_duckdb_udf::(conn, udfs, &mut failed)?;
-    register_duckdb_udf::(conn, udfs, &mut failed)?;
-    register_duckdb_udf::(conn, udfs, &mut failed)?;
-    register_duckdb_udf::(conn, udfs, &mut failed)?;
-    register_duckdb_udf::(conn, udfs, &mut failed)?;
-    register_duckdb_udf::(conn, udfs, &mut failed)?;
-    register_duckdb_udf::(conn, udfs, &mut failed)?;
+    register_duckdb_udf::(conn, udfs, &mut succeeded)?;
+    register_duckdb_udf_try::(
+        conn,
+        udfs,
+        &mut succeeded,
+    )?;
+    register_duckdb_udf::(conn, udfs, &mut succeeded)?;
+    register_duckdb_udf_try::(
+        conn,
+        udfs,
+        &mut succeeded,
+    )?;
+    register_duckdb_udf::(conn, udfs, &mut succeeded)?;
+    register_duckdb_udf::(conn, udfs, &mut succeeded)?;
+    register_duckdb_udf::(conn, udfs, &mut succeeded)?;
+    register_duckdb_udf::(conn, udfs, &mut succeeded)?;
+    register_duckdb_udf::(conn, udfs, &mut succeeded)?;
+    register_duckdb_udf::(conn, udfs, &mut succeeded)?;
+    register_duckdb_udf::(conn, udfs, &mut succeeded)?;
+    register_duckdb_udf::(conn, udfs, &mut succeeded)?;
+    register_duckdb_udf::(conn, udfs, &mut succeeded)?;
+    register_duckdb_udf::(conn, udfs, &mut succeeded)?;
+    register_duckdb_udf::(conn, udfs, &mut succeeded)?;
+    register_duckdb_udf::(conn, udfs, &mut succeeded)?;
+    register_duckdb_udf::(conn, udfs, &mut succeeded)?;
 
     // Regexp
-    register_duckdb_udf::(conn, udfs, &mut failed)?;
-    register_duckdb_udf::(conn, udfs, &mut failed)?;
-    register_duckdb_udf::(conn, udfs, &mut failed)?;
-    register_duckdb_udf::(conn, udfs, &mut failed)?;
-    register_duckdb_udf::(conn, udfs, &mut failed)?;
+    register_duckdb_udf::(conn, udfs, &mut succeeded)?;
+    register_duckdb_udf::(conn, udfs, &mut succeeded)?;
+    register_duckdb_udf::(conn, udfs, &mut succeeded)?;
+    register_duckdb_udf::(conn, udfs, &mut succeeded)?;
+    register_duckdb_udf::(conn, udfs, &mut succeeded)?;
 
     // Conditional
-    register_duckdb_udf::(conn, udfs, &mut failed)?;
-    register_duckdb_udf::(conn, udfs, &mut failed)?;
-    register_duckdb_udf::(conn, udfs, &mut failed)?;
-    register_duckdb_udf::(conn, udfs, &mut failed)?;
-    register_duckdb_udf::(conn, udfs, &mut failed)?;
-    register_duckdb_udf::(conn, udfs, &mut failed)?;
-    register_duckdb_udf::(conn, udfs, &mut failed)?;
+    register_duckdb_udf::(conn, udfs, &mut succeeded)?;
+    register_duckdb_udf::(conn, udfs, &mut succeeded)?;
+    register_duckdb_udf::(conn, udfs, &mut succeeded)?;
+    register_duckdb_udf::(conn, udfs, &mut succeeded)?;
+    register_duckdb_udf::(conn, udfs, &mut succeeded)?;
+    register_duckdb_udf::(conn, udfs, &mut succeeded)?;
+    register_duckdb_udf::(conn, udfs, &mut succeeded)?;
 
     // Crypto
-    register_duckdb_udf::(conn, udfs, &mut failed)?;
+    register_duckdb_udf::(conn, udfs, &mut succeeded)?;
 
     // Datetime
-    register_duckdb_udf::(conn, udfs, &mut failed)?;
-    register_duckdb_udf::(conn, udfs, &mut failed)?;
-    register_duckdb_udf::(conn, udfs, &mut failed)?;
-    register_duckdb_udf::(conn, udfs, &mut failed)?;
-    register_duckdb_udf::(conn, udfs, &mut failed)?;
-    register_duckdb_udf::(conn, udfs, &mut failed)?;
-    register_duckdb_udf::(conn, udfs, &mut failed)?;
-    register_duckdb_udf::(conn, udfs, &mut failed)?;
-    register_duckdb_udf::(conn, udfs, &mut failed)?;
-    register_duckdb_udf::(conn, udfs, &mut failed)?;
+    register_duckdb_udf::(conn, udfs, &mut succeeded)?;
+    register_duckdb_udf::(conn, udfs, &mut succeeded)?;
+    register_duckdb_udf::(conn, udfs, &mut succeeded)?;
+    register_duckdb_udf::(conn, udfs, &mut succeeded)?;
+    register_duckdb_udf::(conn, udfs, &mut succeeded)?;
+    register_duckdb_udf::(conn, udfs, &mut succeeded)?;
+    register_duckdb_udf::(conn, udfs, &mut succeeded)?;
+    register_duckdb_udf::(conn, udfs, &mut succeeded)?;
+    register_duckdb_udf::(conn, udfs, &mut succeeded)?;
+    register_duckdb_udf::(conn, udfs, &mut succeeded)?;
     register_duckdb_udf::(
         conn,
         udfs,
-        &mut failed,
+        &mut succeeded,
     )?;
     for interval in Interval::iter() {
         register_duckdb_udf_internal::(
             conn,
             udfs,
             &interval.to_string(),
-            &mut failed,
+            &mut succeeded,
         )?;
     }
 
     // Numeric
-    register_duckdb_udf_internal::(conn, udfs, "div0null", &mut failed)?;
-    register_duckdb_udf::(conn, udfs, &mut failed)?;
+    register_duckdb_udf_internal::(conn, udfs, "div0null", &mut succeeded)?;
+    register_duckdb_udf::(conn, udfs, &mut succeeded)?;
 
     // System
-    register_duckdb_udf::(conn, udfs, &mut failed)?;
-    register_duckdb_udf::(conn, udfs, &mut failed)?;
-    Ok(failed)
+    register_duckdb_udf::(conn, udfs, &mut succeeded)?;
+    register_duckdb_udf::(conn, udfs, &mut succeeded)?;
+
+    // Session
+    register_duckdb_udf::(conn, udfs, &mut succeeded)?;
+    register_duckdb_udf::(conn, udfs, &mut succeeded)?;
+    register_duckdb_udf::(conn, udfs, &mut succeeded)?;
+
+    // Rewrite function calls in the statement to their duckdb_-prefixed names
+    if let DFStatement::Statement(stmt) = statement {
+        visit(stmt, &succeeded);
+    }
+    Ok(())
 }
 
 /// Registers a normal (non-try_) UDF in `DuckDB`.
 pub fn register_duckdb_udf<T, S>(
     conn: &Connection,
     udfs: &HashMap<String, Arc<ScalarUDF>, S>,
-    failed: &mut Vec<String>,
+    succeeded: &mut HashSet<String, S>,
 ) -> CoreResult<()>
 where
     T: ScalarUDFImpl + Default + 'static,
-    S: BuildHasher,
+    S: BuildHasher + Clone,
 {
     let name = T::default().name().to_string();
-    register_duckdb_udf_internal::<T, S>(conn, udfs, &name, failed)
+    register_duckdb_udf_internal::<T, S>(conn, udfs, &name, succeeded)
 }
 
 /// Registers a “try_” variant of a UDF in `DuckDB`.
 pub fn register_duckdb_udf_try<T, S>(
     conn: &Connection,
     udfs: &HashMap<String, Arc<ScalarUDF>, S>,
-    failed: &mut Vec<String>,
+    succeeded: &mut HashSet<String, S>,
 ) -> CoreResult<()>
 where
     T: ScalarUDFImpl + Default + 'static,
-    S: BuildHasher,
+    S: BuildHasher + Clone,
 {
     let name = format!("try_{}", T::default().name());
-    register_duckdb_udf_internal::<T, S>(conn, udfs, &name, failed)
+    register_duckdb_udf_internal::<T, S>(conn, udfs, &name, succeeded)
 }
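A hedged illustration of the naming layers above (the helper is hypothetical, written only to make the convention concrete, and it assumes the stored UDF's name() matches its map key, as the internal registrar implies): the lookup key into the DataFusion UDF map is the plain name, register_duckdb_udf_try prepends try_, and DuckDB registration always prepends duckdb_ to the looked-up name:

// Hypothetical helper mirroring the naming convention described above.
fn registered_names(df_name: &str, try_variant: bool) -> (String, String) {
    let lookup = if try_variant {
        format!("try_{df_name}")
    } else {
        df_name.to_string()
    };
    let duckdb_name = format!("duckdb_{lookup}");
    (lookup, duckdb_name)
}

fn main() {
    // e.g. the try_ variant of hex_decode_string registers as duckdb_try_hex_decode_string.
    assert_eq!(
        registered_names("hex_decode_string", true),
        (
            "try_hex_decode_string".to_string(),
            "duckdb_try_hex_decode_string".to_string()
        )
    );
}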
 
 /// Shared internal logic for both normal and try_ function registration.
+/// All functions are registered in DuckDB under a `duckdb_` prefix.
 fn register_duckdb_udf_internal<T, S>(
     conn: &Connection,
     udfs: &HashMap<String, Arc<ScalarUDF>, S>,
     name: &str,
-    failed: &mut Vec<String>,
+    succeeded: &mut HashSet<String, S>,
 ) -> CoreResult<()>
 where
     T: ScalarUDFImpl + Default + 'static,
-    S: BuildHasher,
+    S: BuildHasher + Clone,
 {
     let func = udfs
         .get(name)
@@ -246,11 +324,16 @@ where
         .context(ex_error::DataFusionSnafu)?;
 
     let state = UdfState::new(func.clone());
-    if conn
-        .register_scalar_function_with_state::<DfUdfWrapper<T>>(name, &state)
-        .is_err()
+    let duckdb_name = &format!("duckdb_{}", func.name());
+    match conn
+        .register_scalar_function_with_state::<SafeVScalar<DfUdfWrapper<T>>>(duckdb_name, &state)
     {
-        failed.push(name.to_string());
+        Ok(()) => {
+            succeeded.insert(name.to_string());
+        }
+        Err(_) => {
+            tracing::error!("Failed to register {duckdb_name} function in DuckDB");
+        }
     }
     Ok(())
 }
@@ -268,17 +351,18 @@ fn expand_signature(
             .map(|dt| vec![dt; *arg_count])
             .collect::<Vec<_>>()
         }
+        TypeSignature::VariadicAny => {
+            let ret = func
+                .return_type(&[DataType::Utf8])
+                .unwrap_or(DataType::Utf8);
+            return vec![ArrowFunctionSignature::exact(vec![DataType::Utf8], ret)];
+        }
+        TypeSignature::Nullary => {
+            return vec![ArrowFunctionSignature::exact(vec![], DataType::Utf8)];
+        }
         _ => sig.get_example_types(),
     };
 
-    if example_sigs.is_empty() {
-        // Fallback when no examples are available (e.g., for generic or nullary signatures)
-        let ret = func
-            .return_type(&[DataType::Utf8])
-            .unwrap_or(DataType::Utf8);
-        return vec![ArrowFunctionSignature::exact(vec![DataType::Utf8], ret)];
-    }
-
     // Build a DuckDB signature for each valid argument combination
     example_sigs
         .into_iter()
         .map(|arg_types| {
@@ -288,3 +372,32 @@
         })
         .collect()
 }
+
+/// Rewrites function names to `duckdb_*` equivalents if they were registered in `DuckDB`.
+#[derive(Debug)]
+pub struct DuckdbFunctionsRewriter<'a, S: BuildHasher> {
+    pub duckdb_funcs: &'a HashSet<String, S>,
+}
+
+impl<S: BuildHasher> VisitorMut for DuckdbFunctionsRewriter<'_, S> {
+    type Break = ();
+
+    fn post_visit_expr(&mut self, expr: &mut Expr) -> ControlFlow<Self::Break> {
+        if let Expr::Function(func) = expr {
+            let func_name_string = func.name.to_string().to_lowercase();
+
+            if self.duckdb_funcs.contains(&func_name_string) {
+                let new_name = format!("duckdb_{func_name_string}");
+                func.name = ObjectName::from(vec![Ident::new(new_name)]);
+            }
+        }
+        ControlFlow::Continue(())
+    }
+}
+
+pub fn visit<S>(stmt: &mut Statement, duckdb_funcs: &HashSet<String, S>)
+where
+    S: BuildHasher,
+{
+    let _ = stmt.visit(&mut DuckdbFunctionsRewriter { duckdb_funcs });
+}
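The rewriter above only renames calls whose lowercase name made it into the succeeded set. A self-contained sketch of the same pattern against a plain HashSet, assuming a sqlparser version exposing the VisitorMut API used in this diff:

use std::collections::HashSet;
use std::ops::ControlFlow;

use sqlparser::ast::{Expr, Ident, ObjectName, VisitMut, VisitorMut};
use sqlparser::dialect::GenericDialect;
use sqlparser::parser::Parser;

struct Rewriter<'a> {
    funcs: &'a HashSet<String>,
}

impl VisitorMut for Rewriter<'_> {
    type Break = ();

    fn post_visit_expr(&mut self, expr: &mut Expr) -> ControlFlow<Self::Break> {
        if let Expr::Function(func) = expr {
            let name = func.name.to_string().to_lowercase();
            if self.funcs.contains(&name) {
                // Swap the call target for its duckdb_-prefixed registration.
                func.name = ObjectName::from(vec![Ident::new(format!("duckdb_{name}"))]);
            }
        }
        ControlFlow::Continue(())
    }
}

fn main() -> Result<(), Box<dyn std::error::Error>> {
    let funcs: HashSet<String> = ["current_database".to_string()].into_iter().collect();
    let mut stmts = Parser::parse_sql(&GenericDialect {}, "SELECT current_database()")?;
    let _ = stmts[0].visit(&mut Rewriter { funcs: &funcs });
    assert_eq!(stmts[0].to_string(), "SELECT duckdb_current_database()");
    Ok(())
}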
'all'", []) - .context(ex_error::DuckdbSnafu)?; - let mut stmt = conn.prepare(&explain_sql).context(ex_error::DuckdbSnafu)?; - let arrow = stmt.query_arrow([]).context(ex_error::DuckdbSnafu)?; - Ok(arrow.collect()) - }) - .await - .context(ex_error::JoinHandleSnafu)? +pub fn execute_duck_db_explain(conn: &Connection, sql: &str) -> Result> { + let explain_sql = format!("EXPLAIN (format html) {sql}"); + conn.execute("PRAGMA explain_output = 'all'", []) + .context(ex_error::DuckdbSnafu)?; + let (res, _) = query_duck_db_arrow(conn, &explain_sql)?; + Ok(res) } pub fn query_duck_db_arrow( duckdb_conn: &Connection, sql: &str, -) -> Result { +) -> Result<(Vec, SchemaRef)> { // Clone connection for blocking thread - let conn = duckdb_conn.try_clone().context(crate::error::DuckdbSnafu)?; let sql = sql.to_string(); // Prepare statement and get schema @@ -42,43 +28,10 @@ pub fn query_duck_db_arrow( .context(crate::error::DuckdbSnafu)?; let result: duckdb::Arrow<'_> = stmt.query_arrow([]).context(crate::error::DuckdbSnafu)?; let schema = result.get_schema(); - - // Create async channel for record batches - let (batch_tx, batch_rx) = tokio::sync::mpsc::channel::(4); - let cloned_schema = schema.clone(); - - let join_handle = tokio::task::spawn_blocking(move || -> Result<()> { - let mut stmt = conn.prepare(&sql).context(crate::error::DuckdbSnafu)?; - let result: duckdb::ArrowStream<'_> = stmt - .stream_arrow([], cloned_schema) - .context(crate::error::DuckdbSnafu)?; - for batch in result { - blocking_channel_send(&batch_tx, batch)?; - } - Ok(()) - }); - - let stream = ReceiverStream::new(batch_rx) - .map(Ok) - .chain(stream! { - match join_handle.await { - Ok(Err(e)) => yield Err(DataFusionError::Execution(format!("DuckDB query failed: {e}"))), - Err(join_err) => yield Err(DataFusionError::Execution(format!("DuckDB thread join failed: {join_err}"))), - _ => {} - } - }); - Ok(Box::pin(RecordBatchStreamAdapter::new( - schema, - Box::pin(stream), - ))) + let res = result.collect(); + Ok((res, schema)) } -fn blocking_channel_send(channel: &Sender, item: T) -> Result<()> { - channel - .blocking_send(item) - .map_err(|e| DataFusionError::Execution(e.to_string())) - .context(ex_error::DataFusionSnafu) -} #[must_use] pub fn is_select_statement(stmt: &DFStatement) -> bool { matches!(stmt, DFStatement::Statement(inner) if matches!(**inner, Statement::Query(_))) diff --git a/crates/core-executor/src/query.rs b/crates/core-executor/src/query.rs index 4d15b1946..bdff28464 100644 --- a/crates/core-executor/src/query.rs +++ b/crates/core-executor/src/query.rs @@ -85,6 +85,7 @@ use df_catalog::catalog_list::CachedEntity; use df_catalog::table::CachingTable; use duckdb::Connection; use embucket_functions::semi_structured::variant::visitors::visit_all; +use embucket_functions::session::session_prop; use embucket_functions::session_params::SessionProperty; use embucket_functions::visitors::{ copy_into_identifiers, fetch_to_limit, functions_rewriter, inline_aliases_in_query, @@ -92,7 +93,6 @@ use embucket_functions::visitors::{ table_functions_cte_relation, timestamp, top_limit, unimplemented::functions_checker::visit as unimplemented_functions_checker, }; -use futures::TryStreamExt; use iceberg_rust::catalog::Catalog; use iceberg_rust::catalog::create::CreateTableBuilder; use iceberg_rust::catalog::identifier::Identifier; @@ -200,6 +200,10 @@ impl UserQuery { .unwrap_or_else(|| "public".to_string()) } + fn current_version(&self) -> String { + self.session.config.embucket_version.clone() + } + #[instrument( name = 
"UserQuery::refresh_catalog_partially", level = "debug", @@ -508,18 +512,21 @@ impl UserQuery { // Convert already resolved table references to iceberg_scan function call let setup_queries = self.update_iceberg_scan_references(&mut statement).await?; - self.query = statement.to_string(); - let sql = self.query.clone(); let conn = Connection::open_in_memory().context(ex_error::DuckdbSnafu)?; - let failed = register_all_udfs(&conn, self.session.ctx.state().scalar_functions())?; - if !failed.is_empty() { - tracing::warn!( - "Some UDFs were not registered/overloaded in DuckDB: {:?}", - failed - ); - } + // Set session params for session UDFs and register Embucket UDFs in Duckdb connection + self.set_duckdb_udf_params()?; + register_all_udfs( + &conn, + &mut statement, + self.session.ctx.state().scalar_functions(), + )?; + + self.query = statement.to_string(); + let sql = self.query.clone(); + + // Apply setup queries apply_connection_setup_queries(&conn, &setup_queries)?; if self.session.config.use_duck_db_explain @@ -527,17 +534,11 @@ impl UserQuery { .session .get_session_variable_bool("embucket.execution.explain_before_acceleration") { - // Check if possible to call duckdb with this query - let explain_conn = conn.try_clone().context(ex_error::DuckdbSnafu)?; - let _explain_result = execute_duck_db_explain(explain_conn, &sql).await?; + // Run EXPLAIN before query execution to check if the query is supported + let _explain_result = execute_duck_db_explain(&conn, &sql)?; } - let stream = query_duck_db_arrow(&conn, &sql)?; - let schema = stream.schema().clone(); - let records = stream - .try_collect::>() - .await - .context(ex_error::DataFusionSnafu)?; + let (records, schema) = query_duck_db_arrow(&conn, &sql)?; Ok::(QueryResult::new( records, schema, @@ -545,6 +546,22 @@ impl UserQuery { )) } + #[instrument( + name = "UserQuery::set_duckdb_udf_params", + level = "trace", + skip(self), + err + )] + pub fn set_duckdb_udf_params(&self) -> Result<()> { + let params = HashMap::from([ + (session_prop("current_database"), self.current_database()), + (session_prop("current_schema"), self.current_schema()), + (session_prop("current_version"), self.current_version()), + ]); + self.session + .set_session_params(params, &self.session.ctx.session_id()) + } + #[instrument(name = "UserQuery::get_catalog", level = "trace", skip(self), err)] pub fn get_catalog(&self, name: &str) -> Result> { self.session diff --git a/crates/core-executor/src/session.rs b/crates/core-executor/src/session.rs index 19aaf197b..fa1905626 100644 --- a/crates/core-executor/src/session.rs +++ b/crates/core-executor/src/session.rs @@ -190,6 +190,35 @@ impl UserSession { Ok(()) } + pub fn set_session_params( + &self, + properties: HashMap, + session_id: &str, + ) -> Result<()> { + let state = self.ctx.state_ref(); + let mut write = state.write(); + let options = write.config_mut().options_mut(); + + let properties = properties + .into_iter() + .map(|(name, value)| { + let prop = SessionProperty::from_str_value( + name.clone(), + value, + Some(session_id.to_string()), + ); + (name, prop) + }) + .collect(); + + let config = options.extensions.get_mut::(); + if let Some(cfg) = config { + cfg.set_properties(properties) + .context(ex_error::DataFusionSnafu)?; + } + Ok(()) + } + #[must_use] pub fn get_session_variable(&self, variable: &str) -> Option { let state = self.ctx.state(); diff --git a/crates/embucket-functions/src/datetime/mod.rs b/crates/embucket-functions/src/datetime/mod.rs index 6f56d39d2..bd3dbfc6a 100644 --- 
diff --git a/crates/embucket-functions/src/datetime/mod.rs b/crates/embucket-functions/src/datetime/mod.rs
index 6f56d39d2..bd3dbfc6a 100644
--- a/crates/embucket-functions/src/datetime/mod.rs
+++ b/crates/embucket-functions/src/datetime/mod.rs
@@ -33,6 +33,9 @@ pub fn register_udfs(
         date_from_parts::get_udf(),
         dayname::get_udf(),
         Arc::new(ScalarUDF::from(LastDayFunc::new(session_params.clone()))),
+        Arc::new(ScalarUDF::from(ConvertTimezoneFunc::new(
+            session_params.clone(),
+        ))),
         monthname::get_udf(),
         next_day::get_udf(),
         previous_day::get_udf(),
@@ -44,8 +47,5 @@ pub fn register_udfs(
         registry.register_udf(func)?;
     }
     date_part_extract::register_udfs(registry, session_params)?;
-    registry.register_udf(Arc::new(ScalarUDF::from(ConvertTimezoneFunc::new(
-        session_params.to_owned(),
-    ))))?;
     Ok(())
 }
diff --git a/crates/embucket-functions/src/lib.rs b/crates/embucket-functions/src/lib.rs
index 5dc3d0e98..c6c17bc3a 100644
--- a/crates/embucket-functions/src/lib.rs
+++ b/crates/embucket-functions/src/lib.rs
@@ -64,7 +64,7 @@ pub fn register_udfs(
     semi_structured::register_udfs(registry)?;
     regexp::register_udfs(registry)?;
     system::register_udfs(registry)?;
-    session::register_session_context_udfs(registry)?;
+    session::register_session_context_udfs(registry, session_params)?;
     window::register_udwfs(registry)?;
     Ok(())
 }
diff --git a/crates/embucket-functions/src/session/current_database.rs b/crates/embucket-functions/src/session/current_database.rs
new file mode 100644
index 000000000..e43051431
--- /dev/null
+++ b/crates/embucket-functions/src/session/current_database.rs
@@ -0,0 +1,65 @@
+use crate::session::session_prop;
+use crate::session_params::SessionParams;
+use datafusion::arrow::array::StringArray;
+use datafusion::arrow::datatypes::DataType;
+use datafusion::error::Result as DFResult;
+use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility};
+use std::any::Any;
+use std::sync::Arc;
+
+/// Returns the name of the current database, which varies depending on where you call the function.
+#[derive(Debug, PartialEq, Eq, Hash)]
+pub struct CurrentDatabase {
+    signature: Signature,
+    session_params: Arc<SessionParams>,
+}
+
+impl Default for CurrentDatabase {
+    fn default() -> Self {
+        Self::new(Arc::new(SessionParams::default()))
+    }
+}
+
+impl CurrentDatabase {
+    #[must_use]
+    pub fn new(session_params: Arc<SessionParams>) -> Self {
+        Self {
+            signature: Signature::nullary(Volatility::Stable),
+            session_params,
+        }
+    }
+
+    #[must_use]
+    pub fn current_database(&self) -> String {
+        self.session_params
+            .get_property(&session_prop("current_database"))
+            .unwrap_or_else(|| "embucket".to_string())
+    }
+}
+
+impl ScalarUDFImpl for CurrentDatabase {
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn name(&self) -> &'static str {
+        "current_database"
+    }
+
+    fn signature(&self) -> &Signature {
+        &self.signature
+    }
+
+    fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
+        Ok(DataType::Utf8)
+    }
+
+    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
+        let num_rows = args.number_rows;
+        let value = self.current_database();
+        let array = Arc::new(StringArray::from(vec![Some(value.as_str()); num_rows]));
+        Ok(ColumnarValue::Array(array))
+    }
+}
+
+crate::macros::make_udf_function!(CurrentDatabase);
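A hedged usage sketch for the new UDF, assuming the crate-public paths shown in this diff: with no session parameter set, CurrentDatabase falls back to the literal "embucket".

use std::sync::Arc;

use embucket_functions::session::current_database::CurrentDatabase;
use embucket_functions::session_params::SessionParams;

fn main() {
    let udf = CurrentDatabase::new(Arc::new(SessionParams::default()));
    // Default SessionParams carry no embucket.session.current_database entry,
    // so the documented fallback applies.
    assert_eq!(udf.current_database(), "embucket");
}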
diff --git a/crates/embucket-functions/src/session/current_schema.rs b/crates/embucket-functions/src/session/current_schema.rs
new file mode 100644
index 000000000..a1d2da823
--- /dev/null
+++ b/crates/embucket-functions/src/session/current_schema.rs
@@ -0,0 +1,65 @@
+use crate::session::session_prop;
+use crate::session_params::SessionParams;
+use datafusion::arrow::array::StringArray;
+use datafusion::arrow::datatypes::DataType;
+use datafusion::error::Result as DFResult;
+use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility};
+use std::any::Any;
+use std::sync::Arc;
+
+/// Returns the name of the current schema, which varies depending on where you call the function.
+#[derive(Debug, PartialEq, Eq, Hash)]
+pub struct CurrentSchema {
+    signature: Signature,
+    session_params: Arc<SessionParams>,
+}
+
+impl Default for CurrentSchema {
+    fn default() -> Self {
+        Self::new(Arc::new(SessionParams::default()))
+    }
+}
+
+impl CurrentSchema {
+    #[must_use]
+    pub fn new(session_params: Arc<SessionParams>) -> Self {
+        Self {
+            signature: Signature::nullary(Volatility::Stable),
+            session_params,
+        }
+    }
+
+    #[must_use]
+    pub fn current_schema(&self) -> String {
+        self.session_params
+            .get_property(&session_prop("current_schema"))
+            .unwrap_or_else(|| "embucket".to_string())
+    }
+}
+
+impl ScalarUDFImpl for CurrentSchema {
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn name(&self) -> &'static str {
+        "current_schema"
+    }
+
+    fn signature(&self) -> &Signature {
+        &self.signature
+    }
+
+    fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
+        Ok(DataType::Utf8)
+    }
+
+    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
+        let num_rows = args.number_rows;
+        let value = self.current_schema();
+        let array = Arc::new(StringArray::from(vec![Some(value.as_str()); num_rows]));
+        Ok(ColumnarValue::Array(array))
+    }
+}
+
+crate::macros::make_udf_function!(CurrentSchema);
diff --git a/crates/embucket-functions/src/session/current_version.rs b/crates/embucket-functions/src/session/current_version.rs
new file mode 100644
index 000000000..d1f8fa7af
--- /dev/null
+++ b/crates/embucket-functions/src/session/current_version.rs
@@ -0,0 +1,65 @@
+use crate::session::session_prop;
+use crate::session_params::SessionParams;
+use datafusion::arrow::array::StringArray;
+use datafusion::arrow::datatypes::DataType;
+use datafusion::error::Result as DFResult;
+use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility};
+use std::any::Any;
+use std::sync::Arc;
+
+/// Returns the current Embucket version.
+#[derive(Debug, PartialEq, Eq, Hash)]
+pub struct CurrentVersion {
+    signature: Signature,
+    session_params: Arc<SessionParams>,
+}
+
+impl Default for CurrentVersion {
+    fn default() -> Self {
+        Self::new(Arc::new(SessionParams::default()))
+    }
+}
+
+impl CurrentVersion {
+    #[must_use]
+    pub fn new(session_params: Arc<SessionParams>) -> Self {
+        Self {
+            signature: Signature::exact(vec![], Volatility::Volatile),
+            session_params,
+        }
+    }
+
+    #[must_use]
+    pub fn current_version(&self) -> String {
+        self.session_params
+            .get_property(&session_prop("current_version"))
+            .unwrap_or_else(|| env!("CARGO_PKG_VERSION").to_string())
+    }
+}
+
+impl ScalarUDFImpl for CurrentVersion {
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn name(&self) -> &'static str {
+        "current_version"
+    }
+
+    fn signature(&self) -> &Signature {
+        &self.signature
+    }
+
+    fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
+        Ok(DataType::Utf8)
+    }
+
+    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
+        let num_rows = args.number_rows;
+        let value = self.current_version();
+        let array = Arc::new(StringArray::from(vec![Some(value.as_str()); num_rows]));
+        Ok(ColumnarValue::Array(array))
+    }
+}
+
+crate::macros::make_udf_function!(CurrentVersion);
diff --git a/crates/embucket-functions/src/session/mod.rs b/crates/embucket-functions/src/session/mod.rs
index a93edfd06..8369133fb 100644
--- a/crates/embucket-functions/src/session/mod.rs
+++ b/crates/embucket-functions/src/session/mod.rs
@@ -1,5 +1,9 @@
+pub mod current_database;
+pub mod current_schema;
+pub mod current_version;
 mod last_query_id;
 
+use crate::session_params::SessionParams;
 use datafusion::arrow::array::ListArray;
 use datafusion::arrow::datatypes::{DataType, Field};
 use datafusion_common::{Result, ScalarValue};
@@ -21,16 +25,6 @@ macro_rules! create_session_context_udf {
     }};
 }
 
-/// Returns the name of the current database, which varies depending on where you call the function
-fn current_database_udf() -> ScalarUDF {
-    create_session_context_udf!("current_database", "default")
-}
-
-/// Returns the name of the current schema, which varies depending on where you call the function
-fn current_schema_udf() -> ScalarUDF {
-    create_session_context_udf!("current_schema", "default")
-}
-
 /// Returns active search path schemas.
 fn current_schemas_udf() -> ScalarUDF {
     let fun: ScalarFunctionImplementation = Arc::new(move |_args| {
@@ -53,11 +47,6 @@ fn current_warehouse_udf() -> ScalarUDF {
     create_session_context_udf!("current_warehouse", "default")
 }
 
-/// Returns the current Embucket version.
-fn current_version_udf() -> ScalarUDF {
-    create_session_context_udf!("current_version", env!("CARGO_PKG_VERSION"))
-}
-
 /// Returns the version of the client from which the function was called.
 fn current_client_udf() -> ScalarUDF {
     let version = format!("Embucket {}", env!("CARGO_PKG_VERSION"));
@@ -86,23 +75,43 @@ fn current_ip_address_udf() -> ScalarUDF {
     create_session_context_udf!("current_ip_address", "")
 }
 
-pub fn register_session_context_udfs(registry: &mut dyn FunctionRegistry) -> Result<()> {
+pub fn register_session_context_udfs(
+    registry: &mut dyn FunctionRegistry,
+    session_params: &Arc<SessionParams>,
+) -> Result<()> {
     let udfs = [
         current_client_udf(),
-        current_database_udf(),
         current_ip_address_udf(),
         current_role_udf(),
         current_role_type_udf(),
-        current_schema_udf(),
         current_schemas_udf(),
         current_session_udf(),
-        current_version_udf(),
         current_warehouse_udf(),
     ];
 
     for udf in udfs {
         registry.register_udf(udf.into())?;
     }
-    registry.register_udf(last_query_id::get_udf())?;
+
+    let functions: Vec<Arc<ScalarUDF>> = vec![
+        last_query_id::get_udf(),
+        Arc::new(ScalarUDF::from(current_version::CurrentVersion::new(
+            session_params.clone(),
+        ))),
+        Arc::new(ScalarUDF::from(current_database::CurrentDatabase::new(
+            session_params.clone(),
+        ))),
+        Arc::new(ScalarUDF::from(current_schema::CurrentSchema::new(
+            session_params.clone(),
+        ))),
+    ];
+    for func in functions {
+        registry.register_udf(func)?;
+    }
     Ok(())
 }
+
+#[must_use]
+pub fn session_prop(property: &str) -> String {
+    format!("embucket.session.{property}")
+}
diff --git a/crates/embucket-functions/src/tests/utils.rs b/crates/embucket-functions/src/tests/utils.rs
index 57a8aa1de..3c598ebf5 100644
--- a/crates/embucket-functions/src/tests/utils.rs
+++ b/crates/embucket-functions/src/tests/utils.rs
@@ -24,8 +24,9 @@ pub fn create_session() -> Arc<SessionContext> {
         .with_expr_planners(vec![Arc::new(CustomExprPlanner)])
         .build();
     let mut ctx = SessionContext::new_with_state(state);
-    register_session_context_udfs(&mut ctx).unwrap();
-    register_udfs(&mut ctx, &Arc::new(SessionParams::default())).expect("Cannot register UDFs");
+    let session_params = Arc::new(SessionParams::default());
+    register_session_context_udfs(&mut ctx, &session_params).unwrap();
+    register_udfs(&mut ctx, &session_params).expect("Cannot register UDFs");
     register_udafs(&mut ctx).expect("Cannot register UDAFs");
     register_udtfs(&ctx, history_store_mock());
     Arc::new(ctx)
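A hedged end-to-end check that could sit next to the create_session() helper above (the test name and assertion are illustrative, not part of this patch): the session UDFs should now resolve through the shared test context and fall back to their defaults when no session parameters were set.

#[tokio::test]
async fn session_udfs_resolve_with_defaults() -> datafusion::error::Result<()> {
    let ctx = create_session();
    let df = ctx
        .sql("SELECT current_database(), current_schema(), current_version()")
        .await?;
    let batches = df.collect().await?;
    // One row for the nullary calls; defaults apply because no session params were set.
    assert_eq!(batches.iter().map(|b| b.num_rows()).sum::<usize>(), 1);
    Ok(())
}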