diff --git a/crates/runtime/src/datafusion/execution.rs b/crates/runtime/src/datafusion/execution.rs
index b5f2d566e..9fe4dbf06 100644
--- a/crates/runtime/src/datafusion/execution.rs
+++ b/crates/runtime/src/datafusion/execution.rs
@@ -104,12 +104,11 @@ impl SqlExecutor {
         warehouse_name: &str,
     ) -> IceBucketSQLResult> {
         // Update query to use custom JSON functions
-        let query = self.preprocess_query(query);
+        let query = Self::preprocess_query(query);
         let mut statement = self
             .parse_query(query.as_str())
             .context(super::error::DataFusionSnafu)?;
         Self::postprocess_query_statement(&mut statement);
-
         // statement = self.update_statement_references(statement, warehouse_name);
         // query = statement.to_string();
@@ -242,18 +241,15 @@ impl SqlExecutor {
     /// Panics if .
     #[must_use]
     #[allow(clippy::unwrap_used)]
-    #[tracing::instrument(level = "trace", skip(self), ret)]
-    pub fn preprocess_query(&self, query: &str) -> String {
+    #[tracing::instrument(level = "trace", ret)]
+    pub fn preprocess_query(query: &str) -> String {
         // Replace field[0].subfield -> json_get(json_get(field, 0), 'subfield')
         // TODO: This regex should be a static allocation
         let re = regex::Regex::new(r"(\w+.\w+)\[(\d+)][:\.](\w+)").unwrap();
-        let date_add =
-            regex::Regex::new(r"(date|time|timestamp)(_?add|_?diff)\(\s*([a-zA-Z]+),").unwrap();
-
+        //date_add processing moved to `postprocess_query_statement`
        let mut query = re
             .replace_all(query, "json_get(json_get($1, $2), '$3')")
             .to_string();
-        query = date_add.replace_all(&query, "$1$2('$3',").to_string();
         let alter_iceberg_table = regex::Regex::new(r"alter\s+iceberg\s+table").unwrap();
         query = alter_iceberg_table
             .replace_all(&query, "alter table")
@@ -396,6 +392,7 @@ impl SqlExecutor {
             .await
             .context(ih_error::IcebergSnafu)?;
         };
+
         // Create new table
         rest_catalog
             .create_table(
@@ -1334,8 +1331,144 @@ pub fn created_entity_response() -> Result, arrow::error::Arrow
 }
 
 #[cfg(test)]
-mod test {
-    use crate::datafusion::execution::SqlExecutor;
+mod tests {
+    use super::SqlExecutor;
+    use crate::datafusion::{session::SessionParams, type_planner::CustomTypePlanner};
+    use datafusion::sql::parser::Statement as DFStatement;
+    use datafusion::sql::sqlparser::ast::visit_expressions;
+    use datafusion::sql::sqlparser::ast::{Expr, ObjectName};
+    use datafusion::{
+        execution::SessionStateBuilder,
+        prelude::{SessionConfig, SessionContext},
+    };
+    use datafusion_iceberg::planner::IcebergQueryPlanner;
+    use sqlparser::ast::Value;
+    use sqlparser::ast::{
+        Function, FunctionArg, FunctionArgExpr, FunctionArgumentList, FunctionArguments,
+    };
+    use std::ops::ControlFlow;
+    use std::sync::Arc;
+
+    struct Test<'a, T> {
+        input: &'a str,
+        expected: T,
+        should_work: bool,
+    }
+    impl<'a, T> Test<'a, T> {
+        pub const fn new(input: &'a str, expected: T, should_work: bool) -> Self {
+            Self {
+                input,
+                expected,
+                should_work,
+            }
+        }
+    }
+    #[test]
+    #[allow(
+        clippy::unwrap_used,
+        clippy::explicit_iter_loop,
+        clippy::collapsible_match
+    )]
+    fn test_timestamp_keywords_postprocess() {
+        let state = SessionStateBuilder::new()
+            .with_config(
+                SessionConfig::new()
+                    .with_information_schema(true)
+                    .with_option_extension(SessionParams::default())
+                    .set_str("datafusion.sql_parser.dialect", "SNOWFLAKE"),
+            )
+            .with_default_features()
+            .with_query_planner(Arc::new(IcebergQueryPlanner {}))
+            .with_type_planner(Arc::new(CustomTypePlanner {}))
+            .build();
+        let ctx = SessionContext::new_with_state(state);
+        let executor = SqlExecutor::new(ctx).unwrap();
+        let test = vec![
+            Test::new(
+                "SELECT dateadd(year, 5, '2025-06-01')",
+                Value::SingleQuotedString("year".to_owned()),
+                true,
+            ),
+            Test::new(
+                "SELECT dateadd(\"year\", 5, '2025-06-01')",
+                Value::SingleQuotedString("year".to_owned()),
+                true,
+            ),
+            Test::new(
+                "SELECT dateadd('year', 5, '2025-06-01')",
+                Value::SingleQuotedString("year".to_owned()),
+                true,
+            ),
+            Test::new(
+                "SELECT dateadd(\"'year'\", 5, '2025-06-01')",
+                Value::SingleQuotedString("year".to_owned()),
+                false,
+            ),
+            Test::new(
+                "SELECT dateadd(\'year\', 5, '2025-06-01')",
+                Value::SingleQuotedString("year".to_owned()),
+                true,
+            ),
+            Test::new(
+                "SELECT datediff(day, 5, '2025-06-01')",
+                Value::SingleQuotedString("day".to_owned()),
+                true,
+            ),
+            Test::new(
+                "SELECT datediff('week', 5, '2025-06-01')",
+                Value::SingleQuotedString("week".to_owned()),
+                true,
+            ),
+            Test::new(
+                "SELECT datediff(nsecond, 10000000, '2025-06-01')",
+                Value::SingleQuotedString("nsecond".to_owned()),
+                true,
+            ),
+            Test::new(
+                "SELECT date_diff(hour, 5, '2025-06-01')",
+                Value::SingleQuotedString("hour".to_owned()),
+                true,
+            ),
+            Test::new(
+                "SELECT date_add(us, 100000, '2025-06-01')",
+                Value::SingleQuotedString("us".to_owned()),
+                true,
+            ),
+        ];
+        for test in test.iter() {
+            let mut statement = executor.parse_query(test.input).unwrap();
+            SqlExecutor::postprocess_query_statement(&mut statement);
+            if let DFStatement::Statement(statement) = statement {
+                visit_expressions(&statement, |expr| {
+                    if let Expr::Function(Function {
+                        name: ObjectName(idents),
+                        args: FunctionArguments::List(FunctionArgumentList { args, .. }),
+                        ..
+                    }) = expr
+                    {
+                        match idents.first().unwrap().value.as_str() {
+                            "dateadd" | "date_add" | "datediff" | "date_diff" => {
+                                if let FunctionArg::Unnamed(FunctionArgExpr::Expr(ident)) =
+                                    args.iter().next().unwrap()
+                                {
+                                    if let Expr::Value(found) = ident {
+                                        if test.should_work {
+                                            assert_eq!(*found, test.expected);
+                                        } else {
+                                            assert_ne!(*found, test.expected);
+                                        }
+                                    }
+                                }
+                            }
+                            _ => {}
+                        }
+                    }
+                    ControlFlow::<()>::Continue(())
+                });
+            }
+        }
+    }
+
     use datafusion::sql::parser::DFParser;
 
     #[allow(clippy::unwrap_used)]
diff --git a/crates/runtime/src/datafusion/functions/convert_timezone.rs b/crates/runtime/src/datafusion/functions/convert_timezone.rs
index 0b3e0f917..a68906e18 100644
--- a/crates/runtime/src/datafusion/functions/convert_timezone.rs
+++ b/crates/runtime/src/datafusion/functions/convert_timezone.rs
@@ -341,11 +341,10 @@ mod tests {
                 );
                 assert_eq!(
                     result, expected,
-                    "convert_timezone created wrong value for {}",
-                    source_timestamp_tz_value
-                )
+                    "convert_timezone created wrong value for {source_timestamp_tz_value}"
+                );
             }
-            _ => panic!("Conversion of {} failed", source_timestamp_tz),
+            _ => panic!("Conversion of {source_timestamp_tz} failed"),
         }
     }
     #[test]
@@ -371,11 +370,10 @@ mod tests {
                 );
                 assert_eq!(
                     result, expected,
-                    "convert_timezone created wrong value for {}",
-                    source_timestamp_tz_value
-                )
+                    "convert_timezone created wrong value for {source_timestamp_tz_value}"
+                );
             }
-            _ => panic!("Conversion of {} failed", source_timestamp_tz_value),
+            _ => panic!("Conversion of {source_timestamp_tz_value} failed"),
         }
     }
     #[test]
@@ -403,11 +401,10 @@ mod tests {
                 );
                 assert_ne!(
                     result, expected,
-                    "convert_timezone created wrong value for {}",
-                    source_timestamp_tz_value
-                )
+                    "convert_timezone created wrong value for {source_timestamp_tz_value}"
+                );
             }
-            _ => panic!("Conversion of {} failed", source_timestamp_tz_value),
+            _ => panic!("Conversion of {source_timestamp_tz_value} failed"),
         }
     }
 }
diff --git a/crates/runtime/src/datafusion/functions/date_add.rs b/crates/runtime/src/datafusion/functions/date_add.rs
index 2e747c060..856efdb1d 100644
--- a/crates/runtime/src/datafusion/functions/date_add.rs
+++ b/crates/runtime/src/datafusion/functions/date_add.rs
@@ -235,7 +235,7 @@ mod tests {
             )),
         ];
         let fn_args = ScalarFunctionArgs {
-            args: args,
+            args,
             number_rows: 0,
             return_type: &arrow_schema::DataType::Timestamp(
                 arrow_schema::TimeUnit::Microsecond,
@@ -248,7 +248,7 @@ mod tests {
                     Some(1736600400000000i64),
                     Some(Arc::from(String::from("+00").into_boxed_str())),
                 );
-                assert_eq!(&result, &expected, "date_add created a wrong value")
+                assert_eq!(&result, &expected, "date_add created a wrong value");
             }
             _ => panic!("Conversion failed"),
         }
@@ -268,7 +268,7 @@ mod tests {
             ),
         ];
         let fn_args = ScalarFunctionArgs {
-            args: args,
+            args,
             number_rows: 0,
             return_type: &arrow_schema::DataType::Timestamp(
                 arrow_schema::TimeUnit::Microsecond,
@@ -283,7 +283,7 @@ mod tests {
                 )
                 .to_array()
                 .unwrap();
-                assert_eq!(&result, &expected, "date_add created a wrong value")
+                assert_eq!(&result, &expected, "date_add created a wrong value");
             }
             _ => panic!("Conversion failed"),
         }
@@ -303,7 +303,7 @@ mod tests {
             ),
         ];
         let fn_args = ScalarFunctionArgs {
-            args: args,
+            args,
             number_rows: 0,
             return_type: &arrow_schema::DataType::Timestamp(
                 arrow_schema::TimeUnit::Microsecond,
@@ -318,7 +318,7 @@ mod tests {
                 ))
                 .to_array(2)
                 .unwrap();
-                assert_eq!(&result, &expected, "date_add created a wrong value")
+                assert_eq!(&result, &expected, "date_add created a wrong value");
             }
             _ => panic!("Conversion failed"),
         }
diff --git a/crates/runtime/src/datafusion/functions/mod.rs b/crates/runtime/src/datafusion/functions/mod.rs
index b84c117b2..0e275a9f4 100644
--- a/crates/runtime/src/datafusion/functions/mod.rs
+++ b/crates/runtime/src/datafusion/functions/mod.rs
@@ -18,8 +18,10 @@
 use std::sync::Arc;
 
 use datafusion::{common::Result, execution::FunctionRegistry, logical_expr::ScalarUDF};
-use sqlparser::ast::Value::SingleQuotedString;
-use sqlparser::ast::{Expr, Function, FunctionArg, FunctionArgExpr, FunctionArguments};
+use sqlparser::ast::Value::{self, SingleQuotedString};
+use sqlparser::ast::{
+    Expr, Function, FunctionArg, FunctionArgExpr, FunctionArgumentList, FunctionArguments, Ident,
+};
 
 mod convert_timezone;
 mod date_add;
@@ -99,4 +101,18 @@ pub fn visit_functions_expressions(func: &mut Function) {
         _ => func_name,
     };
     func.name = sqlparser::ast::ObjectName(vec![sqlparser::ast::Ident::new(name)]);
+    if let FunctionArguments::List(FunctionArgumentList { args, .. }) = &mut func.args {
+        match func_name {
+            "dateadd" | "date_add" | "datediff" | "date_diff" => {
+                if let Some(FunctionArg::Unnamed(FunctionArgExpr::Expr(ident))) =
+                    args.iter_mut().next()
+                {
+                    if let Expr::Identifier(Ident { value, .. }) = ident {
+                        *ident = Expr::Value(Value::SingleQuotedString(value.clone()));
+                    }
+                }
+            }
+            _ => {}
+        }
+    }
 }
diff --git a/crates/runtime/src/tests/queries.rs b/crates/runtime/src/tests/queries.rs
index b5461bd89..de3fa5a59 100644
--- a/crates/runtime/src/tests/queries.rs
+++ b/crates/runtime/src/tests/queries.rs
@@ -15,8 +15,9 @@
 // specific language governing permissions and limitations
 // under the License.
 
-// use crate::tests::utils::macros::test_query;
+use crate::tests::utils::macros::test_query;
+test_query!(select_date_add_diff, "SELECT dateadd(day, 5, '2025-06-01')");
 
 // // SELECT
 // test_query!(select_star, "SELECT * FROM employee_table");
 // test_query!(select_ilike, "SELECT * ILIKE '%id%' FROM employee_table;");
diff --git a/crates/runtime/src/tests/snapshots/query_select_date_add_diff.snap b/crates/runtime/src/tests/snapshots/query_select_date_add_diff.snap
new file mode 100644
index 000000000..2125cf2bd
--- /dev/null
+++ b/crates/runtime/src/tests/snapshots/query_select_date_add_diff.snap
@@ -0,0 +1,131 @@
+---
+source: crates/runtime/src/tests/queries.rs
+description: "\"SELECT dateadd(day, 5, '2025-06-01')\""
+---
+(
+    Statement(
+        Query(
+            Query {
+                with: None,
+                body: Select(
+                    Select {
+                        select_token: TokenWithSpan {
+                            token: Word(
+                                Word {
+                                    value: "SELECT",
+                                    quote_style: None,
+                                    keyword: SELECT,
+                                },
+                            ),
+                            span: Span(Location(1,1)..Location(1,7)),
+                        },
+                        distinct: None,
+                        top: None,
+                        top_before_distinct: false,
+                        projection: [
+                            UnnamedExpr(
+                                Function(
+                                    Function {
+                                        name: ObjectName(
+                                            [
+                                                Ident {
+                                                    value: "dateadd",
+                                                    quote_style: None,
+                                                    span: Span(Location(1,8)..Location(1,15)),
+                                                },
+                                            ],
+                                        ),
+                                        uses_odbc_syntax: false,
+                                        parameters: None,
+                                        args: List(
+                                            FunctionArgumentList {
+                                                duplicate_treatment: None,
+                                                args: [
+                                                    Unnamed(
+                                                        Expr(
+                                                            Value(
+                                                                SingleQuotedString(
+                                                                    "day",
+                                                                ),
+                                                            ),
+                                                        ),
+                                                    ),
+                                                    Unnamed(
+                                                        Expr(
+                                                            Value(
+                                                                Number(
+                                                                    "5",
+                                                                    false,
+                                                                ),
+                                                            ),
+                                                        ),
+                                                    ),
+                                                    Unnamed(
+                                                        Expr(
+                                                            Value(
+                                                                SingleQuotedString(
+                                                                    "2025-06-01",
+                                                                ),
+                                                            ),
+                                                        ),
+                                                    ),
+                                                ],
+                                                clauses: [],
+                                            },
+                                        ),
+                                        filter: None,
+                                        null_treatment: None,
+                                        over: None,
+                                        within_group: [],
+                                    },
+                                ),
+                            ),
+                        ],
+                        into: None,
+                        from: [],
+                        lateral_views: [],
+                        prewhere: None,
+                        selection: None,
+                        group_by: Expressions(
+                            [],
+                            [],
+                        ),
+                        cluster_by: [],
+                        distribute_by: [],
+                        sort_by: [],
+                        having: None,
+                        named_window: [],
+                        qualify: None,
+                        window_before_qualify: false,
+                        value_table_mode: None,
+                        connect_by: None,
+                    },
+                ),
+                order_by: None,
+                limit: None,
+                limit_by: [],
+                offset: None,
+                fetch: None,
+                locks: [],
+                for_clause: None,
+                settings: None,
+                format_clause: None,
+            },
+        ),
+    ),
+    Ok(
+        [
+            "Projection: dateadd(Utf8(\"day\"), Int64(5), Utf8(\"2025-06-01\"))",
+            "  EmptyRelation",
+        ],
+    ),
+    Ok(
+        [
+            "+--------------------------------------------------+",
+            "| dateadd(Utf8(\"day\"),Int64(5),Utf8(\"2025-06-01\")) |",
+            "+--------------------------------------------------+",
+            "| 2025-06-06T00:00:00                              |",
+            "+--------------------------------------------------+",
+        ],
+    ),
+)
diff --git a/crates/runtime/src/tests/utils.rs b/crates/runtime/src/tests/utils.rs
index 2ee345615..2d3734a67 100644
--- a/crates/runtime/src/tests/utils.rs
+++ b/crates/runtime/src/tests/utils.rs
@@ -15,65 +15,68 @@
 // specific language governing permissions and limitations
 // under the License.
 
-// use crate::datafusion::functions::register_udfs;
-// use datafusion::prelude::{SessionConfig, SessionContext};
+use crate::datafusion::functions::register_udfs;
+use datafusion::prelude::{SessionConfig, SessionContext};
 
-// static TABLE_SETUP: &str = include_str!(r"./queries/table_setup.sql");
+static TABLE_SETUP: &str = include_str!(r"./queries/table_setup.sql");
 
-// #[allow(clippy::unwrap_used)]
-// pub async fn create_df_session() -> SessionContext {
-//     let mut config = SessionConfig::new();
-//     config.options_mut().catalog.information_schema = true;
-//     let mut ctx = SessionContext::new_with_config(config);
+#[allow(clippy::unwrap_used)]
+pub async fn create_df_session() -> SessionContext {
+    let mut config = SessionConfig::new();
+    config.options_mut().catalog.information_schema = true;
+    let mut ctx = SessionContext::new_with_config(config);
 
-//     register_udfs(&mut ctx).unwrap();
+    register_udfs(&mut ctx).unwrap();
 
-//     for query in TABLE_SETUP.split(';') {
-//         if !query.is_empty() {
-//             dbg!("Running query: ", query);
-//             ctx.sql(query).await.unwrap().collect().await.unwrap();
-//         }
-//     }
-//     ctx
-// }
+    for query in TABLE_SETUP.split(';') {
+        if !query.is_empty() {
+            dbg!("Running query: ", query);
+            ctx.sql(query).await.unwrap().collect().await.unwrap();
+        }
+    }
+    ctx
+}
 
-// pub mod macros {
-//     macro_rules! test_query {
-//         ($test_fn_name:ident, $query:expr) => {
-//             paste::paste! {
-//                 #[tokio::test]
-//                 async fn [< query_ $test_fn_name >]() {
-//                     let ctx = crate::tests::utils::create_df_session().await;
-//                     let statement = ctx.state().sql_to_statement($query, "snowflake");
+pub mod macros {
+    macro_rules! test_query {
+        ($test_fn_name:ident, $query:expr) => {
+            paste::paste! {
+                #[tokio::test]
+                async fn [< query_ $test_fn_name >]() {
+                    let ctx = crate::tests::utils::create_df_session().await;
 
-//                     let plan = ctx.state().create_logical_plan($query)
-//                         .await;
+                    let query = crate::datafusion::execution::SqlExecutor::preprocess_query($query);
+                    let mut statement = ctx.state().sql_to_statement(query.as_str(), "snowflake")
+                        .unwrap();
+                    crate::datafusion::execution::SqlExecutor::postprocess_query_statement(&mut statement);
+                    let plan = ctx.state().statement_to_plan(statement.clone())
+                        .await;
+                    //TODO: add our plan processing also
+                    let df = match &plan {
+                        Ok(plan) => {
+                            match ctx.execute_logical_plan(plan.clone()).await {
+                                Ok(df) => {
+                                    let record_batches = df.collect().await.unwrap();
+                                    Ok(datafusion::arrow::util::pretty::pretty_format_batches(&record_batches).unwrap().to_string())
+                                },
+                                Err(e) => Err(e)
+                            }
+                        },
+                        _ => Err(datafusion::error::DataFusionError::Execution("Failed to create logical plan".to_string()))
+                    };
+                    insta::with_settings!({
+                        description => stringify!($query),
+                        omit_expression => true,
+                        prepend_module_to_snapshot => false
+                    }, {
+                        let plan = plan.map(|plan| plan.to_string().split("\n").map(|s| s.to_string()).collect::>());
+                        let df = df.map(|df| df.split("\n").map(|s| s.to_string()).collect::>());
+                        insta::assert_debug_snapshot!((statement, plan, df));
+                    })
+                }
+            }
+        }
+    }
 
-//                     let df = match &plan {
-//                         Ok(plan) => {
-//                             match ctx.execute_logical_plan(plan.clone()).await {
-//                                 Ok(df) => {
-//                                     let record_batches = df.collect().await.unwrap();
-//                                     Ok(datafusion::arrow::util::pretty::pretty_format_batches(&record_batches).unwrap().to_string())
-//                                 },
-//                                 Err(e) => Err(e)
-//                             }
-//                         },
-//                         _ => Err(datafusion::error::DataFusionError::Execution("Failed to create logical plan".to_string()))
-//                     };
-//                     insta::with_settings!({
-//                         description => stringify!($query),
-//                         omit_expression => true,
-//                         prepend_module_to_snapshot => false
-//                     }, {
-//                         let plan = plan.map(|plan| plan.to_string().split("\n").map(|s| s.to_string()).collect::>());
-//                         let df = df.map(|df| df.split("\n").map(|s| s.to_string()).collect::>());
-//                         insta::assert_debug_snapshot!((statement, plan, df));
-//                     })
-//                 }
-//             }
-//         }
-//     }
-
-//     pub(crate) use test_query;
-// }
+    pub(crate) use test_query;
+}
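
Reviewer note: the core change here is that the unit keyword of `dateadd`/`date_add`/`datediff`/`date_diff` is no longer quoted by a regex in `preprocess_query`; it is now rewritten on the parsed AST via `postprocess_query_statement` → `visit_functions_expressions`. The sketch below illustrates that rewrite as a standalone program. It is a minimal sketch, not the patch itself: the `quote_date_unit` helper, the `main` driver, the `SnowflakeDialect` choice, and the pinned `sqlparser` version/feature are assumptions (the repo should use whatever version its workspace pins, with an AST that still exposes `FunctionArgumentList` and `ObjectName(Vec<Ident>)` as the test in this diff does).

```rust
// Cargo.toml (assumed): sqlparser = { version = "0.53", features = ["visitor"] }
// Sketch of the AST-level rewrite: a bare unit identifier such as `day`
// in dateadd/datediff-style calls becomes the string literal 'day'.
use std::ops::ControlFlow;

use sqlparser::ast::{
    visit_expressions_mut, Expr, FunctionArg, FunctionArgExpr, FunctionArgumentList,
    FunctionArguments, Ident, ObjectName, Statement, Value,
};
use sqlparser::dialect::SnowflakeDialect;
use sqlparser::parser::Parser;

// Hypothetical helper name; mirrors what `visit_functions_expressions` does in the patch.
fn quote_date_unit(statement: &mut Statement) {
    let _ = visit_expressions_mut(statement, |expr| {
        if let Expr::Function(func) = expr {
            let ObjectName(idents) = &func.name;
            let is_date_fn = idents.first().is_some_and(|ident| {
                matches!(
                    ident.value.to_lowercase().as_str(),
                    "dateadd" | "date_add" | "datediff" | "date_diff"
                )
            });
            if is_date_fn {
                if let FunctionArguments::List(FunctionArgumentList { args, .. }) = &mut func.args {
                    // Only the first argument (the unit keyword) is touched.
                    if let Some(FunctionArg::Unnamed(FunctionArgExpr::Expr(unit))) =
                        args.first_mut()
                    {
                        if let Expr::Identifier(Ident { value, .. }) = unit {
                            *unit = Expr::Value(Value::SingleQuotedString(value.clone()));
                        }
                    }
                }
            }
        }
        ControlFlow::<()>::Continue(())
    });
}

fn main() {
    let sql = "SELECT dateadd(day, 5, '2025-06-01')";
    let mut statements = Parser::parse_sql(&SnowflakeDialect {}, sql).expect("parse failed");
    quote_date_unit(&mut statements[0]);
    // Prints: SELECT dateadd('day', 5, '2025-06-01')
    println!("{}", statements[0]);
}
```

The new `query_select_date_add_diff` snapshot in this diff exercises the same path end to end through the re-enabled `test_query!` macro; regenerated snapshots can be inspected with `cargo insta review`.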
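For context, `preprocess_query` keeps its regex-based JSON-path rewrite; only the date/time keyword handling moved to the AST pass. Below is a minimal standalone sketch of that remaining rewrite, using the exact pattern and replacement from the hunk above; the `rewrite_json_paths` wrapper name, the `main` driver, and the sample query are illustrative assumptions only.

```rust
// Assumes the `regex` crate as a dependency; the pattern is copied from `preprocess_query`.
fn rewrite_json_paths(query: &str) -> String {
    // Rewrites accesses such as `payload.items[0].name` (or `payload.items[0]:name`)
    // into nested json_get() calls.
    let re = regex::Regex::new(r"(\w+.\w+)\[(\d+)][:\.](\w+)").expect("static pattern compiles");
    re.replace_all(query, "json_get(json_get($1, $2), '$3')")
        .to_string()
}

fn main() {
    let q = "SELECT payload.items[0].name FROM events";
    // Prints: SELECT json_get(json_get(payload.items, 0), 'name') FROM events
    println!("{}", rewrite_json_paths(q));
}
```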