2 changes: 1 addition & 1 deletion crates/core-executor/Cargo.toml
@@ -19,7 +19,7 @@ chrono = { workspace = true }
dashmap = { workspace = true }
async-stream = { version = "0.3.6"}
#duckdb = { version = "=1.3.1", package = "spiceai_duckdb_fork" } # Forked to add support for duckdb_scan_arrow, pending: https://github.com/duckdb/duckdb-rs/pull/488
duckdb = { package = "duckdb", version = "1.4.1", features = ["vscalar", "vscalar-arrow", "bundled"] }
duckdb = { package = "duckdb", version = "1.4.1", features = ["vscalar", "vscalar-arrow", "bundled", "vtab", "vtab-arrow"] }
#datafusion-table-providers = { version = "0.8.1", features = ["duckdb"] }
datafusion = { workspace = true }
datafusion-common = { workspace = true }
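The two added features are what the reworked `functions.rs` presumably relies on: `vtab` exposes DuckDB's table-function (virtual table) API to Rust, and `vtab-arrow` layers Arrow conversions on top of it. A minimal sketch of what `vtab-arrow` enables, using the built-in `arrow(?, ?)` table function and the `arrow_recordbatch_to_query_params` helper from duckdb-rs (hedged: the helper's exact re-export path may differ between versions):

```rust
use duckdb::{arrow_recordbatch_to_query_params, Connection, Result};
use duckdb::arrow::array::RecordBatch;

// Feed an existing Arrow RecordBatch back into DuckDB as a table
// and collect the query result as Arrow batches again.
fn roundtrip(conn: &Connection, rb: RecordBatch) -> Result<Vec<RecordBatch>> {
    let params = arrow_recordbatch_to_query_params(rb);
    let mut stmt = conn.prepare("SELECT * FROM arrow(?, ?)")?;
    Ok(stmt.query_arrow(params)?.collect())
}
```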
271 changes: 192 additions & 79 deletions crates/core-executor/src/duckdb/functions.rs

Large diffs are not rendered by default.

67 changes: 10 additions & 57 deletions crates/core-executor/src/duckdb/query.rs
@@ -1,39 +1,25 @@
use crate::error::{self as ex_error, Result};
use async_stream::stream;
use arrow_schema::SchemaRef;
use datafusion::arrow::array::RecordBatch;
use datafusion::execution::SendableRecordBatchStream;
use datafusion::sql::parser::Statement as DFStatement;
use datafusion_common::DataFusionError;
use datafusion_physical_plan::stream::RecordBatchStreamAdapter;
use duckdb::Connection;
use futures::StreamExt;
use snafu::ResultExt;
use sqlparser::ast::Statement;
use std::sync::Arc;
use tokio::sync::mpsc::Sender;
use tokio_stream::wrappers::ReceiverStream;

pub async fn execute_duck_db_explain(conn: Connection, sql: &str) -> Result<Vec<RecordBatch>> {
let sql = sql.to_string();

tokio::task::spawn_blocking(move || {
let explain_sql = format!("EXPLAIN (format html) {sql}");
conn.execute("PRAGMA explain_output = 'all'", [])
.context(ex_error::DuckdbSnafu)?;
let mut stmt = conn.prepare(&explain_sql).context(ex_error::DuckdbSnafu)?;
let arrow = stmt.query_arrow([]).context(ex_error::DuckdbSnafu)?;
Ok(arrow.collect())
})
.await
.context(ex_error::JoinHandleSnafu)?
pub fn execute_duck_db_explain(conn: &Connection, sql: &str) -> Result<Vec<RecordBatch>> {
let explain_sql = format!("EXPLAIN (format html) {sql}");
conn.execute("PRAGMA explain_output = 'all'", [])
.context(ex_error::DuckdbSnafu)?;
let (res, _) = query_duck_db_arrow(conn, &explain_sql)?;
Ok(res)
}

pub fn query_duck_db_arrow(
duckdb_conn: &Connection,
sql: &str,
) -> Result<SendableRecordBatchStream> {
) -> Result<(Vec<RecordBatch>, SchemaRef)> {
// Clone connection for blocking thread
let conn = duckdb_conn.try_clone().context(crate::error::DuckdbSnafu)?;
let sql = sql.to_string();

// Prepare statement and get schema
@@ -42,43 +28,10 @@ pub fn query_duck_db_arrow(
.context(crate::error::DuckdbSnafu)?;
let result: duckdb::Arrow<'_> = stmt.query_arrow([]).context(crate::error::DuckdbSnafu)?;
let schema = result.get_schema();

// Create async channel for record batches
let (batch_tx, batch_rx) = tokio::sync::mpsc::channel::<RecordBatch>(4);
let cloned_schema = schema.clone();

let join_handle = tokio::task::spawn_blocking(move || -> Result<()> {
let mut stmt = conn.prepare(&sql).context(crate::error::DuckdbSnafu)?;
let result: duckdb::ArrowStream<'_> = stmt
.stream_arrow([], cloned_schema)
.context(crate::error::DuckdbSnafu)?;
for batch in result {
blocking_channel_send(&batch_tx, batch)?;
}
Ok(())
});

let stream = ReceiverStream::new(batch_rx)
.map(Ok)
.chain(stream! {
match join_handle.await {
Ok(Err(e)) => yield Err(DataFusionError::Execution(format!("DuckDB query failed: {e}"))),
Err(join_err) => yield Err(DataFusionError::Execution(format!("DuckDB thread join failed: {join_err}"))),
_ => {}
}
});
Ok(Box::pin(RecordBatchStreamAdapter::new(
schema,
Box::pin(stream),
)))
let res = result.collect();
Ok((res, schema))
}

fn blocking_channel_send<T>(channel: &Sender<T>, item: T) -> Result<()> {
channel
.blocking_send(item)
.map_err(|e| DataFusionError::Execution(e.to_string()))
.context(ex_error::DataFusionSnafu)
}
#[must_use]
pub fn is_select_statement(stmt: &DFStatement) -> bool {
matches!(stmt, DFStatement::Statement(inner) if matches!(**inner, Statement::Query(_)))
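For orientation, here is a self-contained sketch of the pattern the rewritten `query_duck_db_arrow` now follows, using only duckdb-rs APIs that appear in the diff: prepare, `query_arrow`, capture the schema, then buffer every batch. The tradeoff against the old channel-based version is that the result set is fully materialized in memory before the caller sees it, in exchange for dropping the `spawn_blocking` and mpsc machinery:

```rust
use duckdb::{Connection, Result};
use duckdb::arrow::array::RecordBatch;
use duckdb::arrow::datatypes::SchemaRef;

// Buffer an entire Arrow result set, as the rewritten function does.
fn collect_arrow(conn: &Connection, sql: &str) -> Result<(Vec<RecordBatch>, SchemaRef)> {
    let mut stmt = conn.prepare(sql)?;
    let arrow = stmt.query_arrow([])?;
    let schema = arrow.get_schema();
    Ok((arrow.collect(), schema))
}

fn main() -> Result<()> {
    let conn = Connection::open_in_memory()?;
    let (batches, schema) = collect_arrow(&conn, "SELECT 42 AS answer")?;
    println!("{} batch(es), {} column(s)", batches.len(), schema.fields().len());
    Ok(())
}
```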
55 changes: 36 additions & 19 deletions crates/core-executor/src/query.rs
@@ -85,14 +85,14 @@ use df_catalog::catalog_list::CachedEntity;
use df_catalog::table::CachingTable;
use duckdb::Connection;
use embucket_functions::semi_structured::variant::visitors::visit_all;
use embucket_functions::session::session_prop;
use embucket_functions::session_params::SessionProperty;
use embucket_functions::visitors::{
copy_into_identifiers, fetch_to_limit, functions_rewriter, inline_aliases_in_query,
like_ilike_any, rlike_regexp_expr_rewriter, select_expr_aliases, table_functions,
table_functions_cte_relation, timestamp, top_limit,
unimplemented::functions_checker::visit as unimplemented_functions_checker,
};
use futures::TryStreamExt;
use iceberg_rust::catalog::Catalog;
use iceberg_rust::catalog::create::CreateTableBuilder;
use iceberg_rust::catalog::identifier::Identifier;
@@ -200,6 +200,10 @@ impl UserQuery {
.unwrap_or_else(|| "public".to_string())
}

fn current_version(&self) -> String {
self.session.config.embucket_version.clone()
}

#[instrument(
name = "UserQuery::refresh_catalog_partially",
level = "debug",
@@ -508,43 +512,56 @@ impl UserQuery {

// Convert already resolved table references to iceberg_scan function call
let setup_queries = self.update_iceberg_scan_references(&mut statement).await?;
self.query = statement.to_string();
let sql = self.query.clone();

let conn = Connection::open_in_memory().context(ex_error::DuckdbSnafu)?;
let failed = register_all_udfs(&conn, self.session.ctx.state().scalar_functions())?;
if !failed.is_empty() {
tracing::warn!(
"Some UDFs were not registered/overloaded in DuckDB: {:?}",
failed
);
}

// Set session params for session UDFs and register Embucket UDFs in Duckdb connection
self.set_duckdb_udf_params()?;
register_all_udfs(
&conn,
&mut statement,
self.session.ctx.state().scalar_functions(),
)?;

self.query = statement.to_string();
let sql = self.query.clone();

// Apply setup queries
apply_connection_setup_queries(&conn, &setup_queries)?;

if self.session.config.use_duck_db_explain
|| self
.session
.get_session_variable_bool("embucket.execution.explain_before_acceleration")
{
// Check if possible to call duckdb with this query
let explain_conn = conn.try_clone().context(ex_error::DuckdbSnafu)?;
let _explain_result = execute_duck_db_explain(explain_conn, &sql).await?;
// Run EXPLAIN before query execution to check if the query is supported
let _explain_result = execute_duck_db_explain(&conn, &sql)?;
}

let stream = query_duck_db_arrow(&conn, &sql)?;
let schema = stream.schema().clone();
let records = stream
.try_collect::<Vec<_>>()
.await
.context(ex_error::DataFusionSnafu)?;
let (records, schema) = query_duck_db_arrow(&conn, &sql)?;
Ok::<QueryResult, Error>(QueryResult::new(
records,
schema,
self.query_context.query_id,
))
}

#[instrument(
name = "UserQuery::set_duckdb_udf_params",
level = "trace",
skip(self),
err
)]
pub fn set_duckdb_udf_params(&self) -> Result<()> {
let params = HashMap::from([
(session_prop("current_database"), self.current_database()),
(session_prop("current_schema"), self.current_schema()),
(session_prop("current_version"), self.current_version()),
]);
self.session
.set_session_params(params, &self.session.ctx.session_id())
}

#[instrument(name = "UserQuery::get_catalog", level = "trace", skip(self), err)]
pub fn get_catalog(&self, name: &str) -> Result<Arc<dyn CatalogProvider>> {
self.session
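The EXPLAIN call above acts as a cheap capability probe: if DuckDB cannot plan the rewritten statement, the failure surfaces before any data is read. A hedged standalone sketch of the idea (the PR routes this through `execute_duck_db_explain` with `PRAGMA explain_output` instead; `duckdb_can_plan` here is illustrative only):

```rust
// Returns true if DuckDB can plan the statement; planning failures
// surface at prepare/execute time without running the query itself.
fn duckdb_can_plan(conn: &duckdb::Connection, sql: &str) -> bool {
    conn.prepare(&format!("EXPLAIN {sql}"))
        .and_then(|mut stmt| stmt.query_arrow([]).map(|_| ()))
        .is_ok()
}
```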
29 changes: 29 additions & 0 deletions crates/core-executor/src/session.rs
@@ -190,6 +190,35 @@ impl UserSession {
Ok(())
}

pub fn set_session_params(
&self,
properties: HashMap<String, String>,
session_id: &str,
) -> Result<()> {
let state = self.ctx.state_ref();
let mut write = state.write();
let options = write.config_mut().options_mut();

let properties = properties
.into_iter()
.map(|(name, value)| {
let prop = SessionProperty::from_str_value(
name.clone(),
value,
Some(session_id.to_string()),
);
(name, prop)
})
.collect();

let config = options.extensions.get_mut::<SessionParams>();
if let Some(cfg) = config {
cfg.set_properties(properties)
.context(ex_error::DataFusionSnafu)?;
}
Ok(())
}

#[must_use]
pub fn get_session_variable(&self, variable: &str) -> Option<String> {
let state = self.ctx.state();
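Putting the two halves together: `set_session_params` writes values into the `SessionParams` config extension under a session-scoped key, and the session UDFs (see `current_database.rs` below) read them back via `get_property`. A hedged sketch of the producer side, reusing names from the diff (`UserSession`, `session_prop`); the wrapper name `publish_udf_context` is illustrative only:

```rust
use std::collections::HashMap;
use embucket_functions::session::session_prop;

// Illustrative wrapper mirroring UserQuery::set_duckdb_udf_params above.
fn publish_udf_context(session: &UserSession) -> Result<()> {
    let params = HashMap::from([
        (session_prop("current_database"), "embucket".to_string()),
        (session_prop("current_schema"), "public".to_string()),
    ]);
    // Scoped to this session id; read back later through
    // SessionParams::get_property inside the session UDFs.
    session.set_session_params(params, &session.ctx.session_id())
}
```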
6 changes: 3 additions & 3 deletions crates/embucket-functions/src/datetime/mod.rs
@@ -33,6 +33,9 @@ pub fn register_udfs(
date_from_parts::get_udf(),
dayname::get_udf(),
Arc::new(ScalarUDF::from(LastDayFunc::new(session_params.clone()))),
Arc::new(ScalarUDF::from(ConvertTimezoneFunc::new(
session_params.clone(),
))),
monthname::get_udf(),
next_day::get_udf(),
previous_day::get_udf(),
@@ -44,8 +47,5 @@
registry.register_udf(func)?;
}
date_part_extract::register_udfs(registry, session_params)?;
registry.register_udf(Arc::new(ScalarUDF::from(ConvertTimezoneFunc::new(
session_params.to_owned(),
))))?;
Ok(())
}
2 changes: 1 addition & 1 deletion crates/embucket-functions/src/lib.rs
@@ -64,7 +64,7 @@ pub fn register_udfs(
semi_structured::register_udfs(registry)?;
regexp::register_udfs(registry)?;
system::register_udfs(registry)?;
session::register_session_context_udfs(registry)?;
session::register_session_context_udfs(registry, session_params)?;
window::register_udwfs(registry)?;
Ok(())
}
65 changes: 65 additions & 0 deletions crates/embucket-functions/src/session/current_database.rs
@@ -0,0 +1,65 @@
use crate::session::session_prop;
use crate::session_params::SessionParams;
use datafusion::arrow::array::StringArray;
use datafusion::arrow::datatypes::DataType;
use datafusion::error::Result as DFResult;
use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility};
use std::any::Any;
use std::sync::Arc;

/// Returns the name of the current database, which varies depending on where you call the function
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct CurrentDatabase {
signature: Signature,
session_params: Arc<SessionParams>,
}

impl Default for CurrentDatabase {
fn default() -> Self {
Self::new(Arc::new(SessionParams::default()))
}
}

impl CurrentDatabase {
#[must_use]
pub fn new(session_params: Arc<SessionParams>) -> Self {
Self {
signature: Signature::nullary(Volatility::Stable),
session_params,
}
}

#[must_use]
pub fn current_database(&self) -> String {
self.session_params
.get_property(&session_prop("current_database"))
.unwrap_or_else(|| "embucket".to_string())
}
}

impl ScalarUDFImpl for CurrentDatabase {
fn as_any(&self) -> &dyn Any {
self
}

fn name(&self) -> &'static str {
"current_database"
}

fn signature(&self) -> &Signature {
&self.signature
}

fn return_type(&self, _arg_types: &[DataType]) -> DFResult<DataType> {
Ok(DataType::Utf8)
}

fn invoke_with_args(&self, args: ScalarFunctionArgs) -> DFResult<ColumnarValue> {
let num_rows = args.number_rows;
let value = self.current_database();
let array = Arc::new(StringArray::from(vec![Some(value.as_str()); num_rows]));
Ok(ColumnarValue::Array(array))
}
}

crate::macros::make_udf_function!(CurrentDatabase);
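A hedged usage sketch for the new UDF on a bare DataFusion context; the `CurrentDatabase` import path is assumed from this file's location, everything else is stock DataFusion. With no session params set, the `Default` impl applies and the function returns the `"embucket"` fallback:

```rust
use datafusion::prelude::SessionContext;
use datafusion_expr::ScalarUDF;
use embucket_functions::session::current_database::CurrentDatabase;

#[tokio::main]
async fn main() -> datafusion::error::Result<()> {
    let ctx = SessionContext::new();
    // Register with the default (empty) session params.
    ctx.register_udf(ScalarUDF::from(CurrentDatabase::default()));
    let df = ctx.sql("SELECT current_database()").await?;
    df.show().await?; // prints: embucket
    Ok(())
}
```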