diff --git a/Cargo.toml b/Cargo.toml index c463c8521..eca9116c2 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -36,8 +36,7 @@ snafu = { version = "0.8.5", features = ["futures"] } tracing = { version = "0.1" } [patch.crates-io] -datafusion = { git="https://github.com/Embucket/datafusion.git", rev = "d176c5872a93b0c5dfdd1d2bc717ad22739e9c18" } - +datafusion = { git="https://github.com/Embucket/datafusion.git", rev = "4930757e1108fbe33704b29d1aa222a3d6214584" } [workspace.lints.clippy] all={ level="deny", priority=-1 } diff --git a/crates/control_plane/Cargo.toml b/crates/control_plane/Cargo.toml index 00dab2b80..5ae28f20a 100644 --- a/crates/control_plane/Cargo.toml +++ b/crates/control_plane/Cargo.toml @@ -18,13 +18,13 @@ flatbuffers = { version = "24.3.25" } #iceberg-rest-catalog = { git = "https://github.com/JanKaul/iceberg-rust.git", rev = "836f11f" } #datafusion_iceberg = { git = "https://github.com/JanKaul/iceberg-rust.git", rev = "836f11f" } -datafusion = { git="https://github.com/Embucket/datafusion.git", rev = "d176c5872a93b0c5dfdd1d2bc717ad22739e9c18" } -datafusion-common = { git="https://github.com/Embucket/datafusion.git", rev = "d176c5872a93b0c5dfdd1d2bc717ad22739e9c18" } -datafusion-expr = { git="https://github.com/Embucket/datafusion.git", rev = "d176c5872a93b0c5dfdd1d2bc717ad22739e9c18" } +datafusion = { git="https://github.com/Embucket/datafusion.git", rev = "4930757e1108fbe33704b29d1aa222a3d6214584" } +datafusion-common = { git="https://github.com/Embucket/datafusion.git", rev = "4930757e1108fbe33704b29d1aa222a3d6214584" } +datafusion-expr = { git="https://github.com/Embucket/datafusion.git", rev = "4930757e1108fbe33704b29d1aa222a3d6214584" } -iceberg-rust = { git = "https://github.com/Embucket/iceberg-rust.git", rev = "4114c25c46d0c7ad272031e61ece1e62a892ddfc" } -iceberg-rest-catalog = { git = "https://github.com/Embucket/iceberg-rust.git", rev = "4114c25c46d0c7ad272031e61ece1e62a892ddfc" } -datafusion_iceberg = { git = "https://github.com/Embucket/iceberg-rust.git", rev = "4114c25c46d0c7ad272031e61ece1e62a892ddfc" } +iceberg-rust = { git = "https://github.com/Embucket/iceberg-rust.git", rev = "7f4c767e9f8e4398a01a37190c30be3864066e34" } +iceberg-rest-catalog = { git = "https://github.com/Embucket/iceberg-rust.git", rev = "7f4c767e9f8e4398a01a37190c30be3864066e34" } +datafusion_iceberg = { git = "https://github.com/Embucket/iceberg-rust.git", rev = "7f4c767e9f8e4398a01a37190c30be3864066e34" } arrow = { version = "53" } arrow-json = { version = "53" } diff --git a/crates/control_plane/src/service.rs b/crates/control_plane/src/service.rs index 73ca136f5..10530f7b5 100644 --- a/crates/control_plane/src/service.rs +++ b/crates/control_plane/src/service.rs @@ -18,6 +18,7 @@ use iceberg_rest_catalog::catalog::RestCatalog; use object_store::path::Path; use object_store::{ObjectStore, PutPayload}; use runtime::datafusion::execution::SqlExecutor; +use runtime::datafusion::type_planner::CustomTypePlanner; use rusoto_core::{HttpClient, Region}; use rusoto_credential::StaticProvider; use rusoto_s3::{GetBucketAclRequest, S3Client, S3}; @@ -248,6 +249,7 @@ impl ControlService for ControlServiceImpl { ) .with_default_features() .with_query_planner(Arc::new(IcebergQueryPlanner {})) + .with_type_planner(Arc::new(CustomTypePlanner {})) .build(); let ctx = SessionContext::new_with_state(state); @@ -407,7 +409,12 @@ impl ControlService for ControlServiceImpl { object_store_builder, ); let catalog = IcebergCatalog::new(Arc::new(rest_client), None).await?; - let ctx = SessionContext::new(); + let 
state = SessionStateBuilder::new() + .with_default_features() + .with_query_planner(Arc::new(IcebergQueryPlanner {})) + .build(); + + let ctx = SessionContext::new_with_state(state); ctx.register_catalog(warehouse_name.clone(), Arc::new(catalog)); // Register CSV file as a table diff --git a/crates/nexus/src/http/dbt/schemas.rs b/crates/nexus/src/http/dbt/schemas.rs index 4dd02d520..216350df2 100644 --- a/crates/nexus/src/http/dbt/schemas.rs +++ b/crates/nexus/src/http/dbt/schemas.rs @@ -16,7 +16,7 @@ pub struct LoginRequestQuery { #[serde(rename = "warehouse")] pub warehouse: String, #[serde(rename = "roleName")] - pub role_name: String, + pub role_name: Option<String>, } #[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)] diff --git a/crates/runtime/Cargo.toml b/crates/runtime/Cargo.toml index 2320be622..68f65b6fa 100644 --- a/crates/runtime/Cargo.toml +++ b/crates/runtime/Cargo.toml @@ -15,13 +15,13 @@ serde_json = { workspace = true } object_store = { workspace = true } tracing = { workspace = true} -datafusion = { git="https://github.com/Embucket/datafusion.git", rev = "d176c5872a93b0c5dfdd1d2bc717ad22739e9c18" } -datafusion-common = { git="https://github.com/Embucket/datafusion.git", rev = "d176c5872a93b0c5dfdd1d2bc717ad22739e9c18" } -datafusion-expr = { git="https://github.com/Embucket/datafusion.git", rev = "d176c5872a93b0c5dfdd1d2bc717ad22739e9c18" } +datafusion = { git="https://github.com/Embucket/datafusion.git", rev = "4930757e1108fbe33704b29d1aa222a3d6214584" } +datafusion-common = { git="https://github.com/Embucket/datafusion.git", rev = "4930757e1108fbe33704b29d1aa222a3d6214584" } +datafusion-expr = { git="https://github.com/Embucket/datafusion.git", rev = "4930757e1108fbe33704b29d1aa222a3d6214584" } -iceberg-rust = { git = "https://github.com/Embucket/iceberg-rust.git", rev = "4114c25c46d0c7ad272031e61ece1e62a892ddfc" } -iceberg-rest-catalog = { git = "https://github.com/Embucket/iceberg-rust.git", rev = "4114c25c46d0c7ad272031e61ece1e62a892ddfc" } -datafusion_iceberg = { git = "https://github.com/Embucket/iceberg-rust.git", rev = "4114c25c46d0c7ad272031e61ece1e62a892ddfc" } +iceberg-rust = { git = "https://github.com/Embucket/iceberg-rust.git", rev = "7f4c767e9f8e4398a01a37190c30be3864066e34" } +iceberg-rest-catalog = { git = "https://github.com/Embucket/iceberg-rust.git", rev = "7f4c767e9f8e4398a01a37190c30be3864066e34" } +datafusion_iceberg = { git = "https://github.com/Embucket/iceberg-rust.git", rev = "7f4c767e9f8e4398a01a37190c30be3864066e34" } arrow = { version = "53" } arrow-json = { version = "53" } diff --git a/crates/runtime/src/datafusion/context.rs b/crates/runtime/src/datafusion/context.rs deleted file mode 100644 index 194b60d5d..000000000 --- a/crates/runtime/src/datafusion/context.rs +++ /dev/null @@ -1,134 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied.
See the License for the -// specific language governing permissions and limitations -// under the License. - -use arrow::datatypes::{DataType, SchemaRef}; -use std::{collections::HashMap, sync::Arc}; - -use datafusion::common::file_options::file_type::FileType; -use datafusion::common::plan_datafusion_err; -use datafusion::common::Result; -use datafusion::config::ConfigOptions; -use datafusion::datasource::default_table_source::provider_as_source; -use datafusion::execution::session_state::SessionState; -use datafusion::logical_expr::var_provider::is_system_variables; -use datafusion::logical_expr::WindowUDF; -use datafusion::logical_expr::{AggregateUDF, ScalarUDF, TableSource}; -use datafusion::prelude::*; -use datafusion::sql::planner::ContextProvider; -use datafusion::sql::TableReference; -use datafusion::variable::VarType; - -pub struct CustomContextProvider<'a> { - pub(crate) state: &'a SessionState, - pub(crate) tables: HashMap>, -} - -impl ContextProvider for CustomContextProvider<'_> { - fn get_table_source(&self, name: TableReference) -> Result> { - let catalog = self.state.config_options().catalog.clone(); - let name = name.resolve(&catalog.default_catalog, &catalog.default_schema); - // println!("Table name: {:?}, to_string {}", name, name.to_string()); - // println!("Tables: {:?}", self.tables.keys()); - self.tables - .get(&name.to_string()) - .cloned() - .ok_or_else(|| plan_datafusion_err!("table '{name}' not found")) - } - - fn get_file_type(&self, ext: &str) -> Result> { - self.state - .get_file_format_factory(ext) - .ok_or(plan_datafusion_err!( - "There is no registered file format with ext {ext}" - )) - .map(|file_type| { - datafusion::datasource::file_format::format_as_file_type(file_type.clone()) - }) - } - - fn get_table_function_source( - &self, - name: &str, - args: Vec, - ) -> Result> { - let tbl_func = self - .state - .table_functions() - .get(name) - .cloned() - .ok_or_else(|| plan_datafusion_err!("table function '{name}' not found"))?; - let provider = tbl_func.create_table_provider(&args)?; - - Ok(provider_as_source(provider)) - } - - /// Create a new CTE work table for a recursive CTE logical plan - /// This table will be used in conjunction with a Worktable physical plan - /// to read and write each iteration of a recursive CTE - fn create_cte_work_table(&self, name: &str, schema: SchemaRef) -> Result> { - let table = Arc::new(datafusion::datasource::cte_worktable::CteWorkTable::new( - name, schema, - )); - Ok(provider_as_source(table)) - } - - fn get_function_meta(&self, name: &str) -> Option> { - self.state.scalar_functions().get(name).cloned() - } - - fn get_aggregate_meta(&self, name: &str) -> Option> { - self.state.aggregate_functions().get(name).cloned() - } - - fn get_window_meta(&self, name: &str) -> Option> { - self.state.window_functions().get(name).cloned() - } - - fn get_variable_type(&self, variable_names: &[String]) -> Option { - if variable_names.is_empty() { - return None; - } - - let provider_type = if is_system_variables(variable_names) { - VarType::System - } else { - VarType::UserDefined - }; - - self.state - .execution_props() - .var_providers - .as_ref() - .and_then(|provider| provider.get(&provider_type)?.get_type(variable_names)) - } - - fn options(&self) -> &ConfigOptions { - self.state.config_options() - } - - fn udf_names(&self) -> Vec { - self.state.scalar_functions().keys().cloned().collect() - } - - fn udaf_names(&self) -> Vec { - self.state.aggregate_functions().keys().cloned().collect() - } - - fn udwf_names(&self) -> Vec { - 
self.state.window_functions().keys().cloned().collect() - } -} diff --git a/crates/runtime/src/datafusion/execution.rs b/crates/runtime/src/datafusion/execution.rs index bb73ea860..c772a9353 100644 --- a/crates/runtime/src/datafusion/execution.rs +++ b/crates/runtime/src/datafusion/execution.rs @@ -2,7 +2,6 @@ #![allow(clippy::missing_panics_doc)] use super::error::{self as ih_error, IcehutSQLError, IcehutSQLResult}; -use crate::datafusion::context::CustomContextProvider; use crate::datafusion::functions::register_udfs; use crate::datafusion::planner::ExtendedSqlToRel; use arrow::array::{RecordBatch, UInt64Array}; @@ -10,14 +9,16 @@ use arrow::datatypes::{DataType, Field, Schema as ArrowSchema}; use datafusion::common::tree_node::{TransformedResult, TreeNode}; use datafusion::datasource::default_table_source::provider_as_source; use datafusion::execution::context::SessionContext; +use datafusion::execution::session_state::SessionContextProvider; use datafusion::logical_expr::sqlparser::ast::Insert; use datafusion::logical_expr::LogicalPlan; use datafusion::sql::parser::{CreateExternalTable, DFParser, Statement as DFStatement}; +use datafusion::sql::planner::IdentNormalizer; use datafusion::sql::sqlparser::ast::{ CreateTable as CreateTableStatement, Expr, Ident, ObjectName, Query, SchemaName, Statement, TableFactor, TableWithJoins, }; -use datafusion_common::DataFusionError; +use datafusion_common::{DataFusionError, TableReference}; use datafusion_functions_json::register_all; use datafusion_iceberg::catalog::catalog::IcebergCatalog; use datafusion_iceberg::planner::iceberg_transform; @@ -40,13 +41,18 @@ use std::sync::Arc; pub struct SqlExecutor { ctx: SessionContext, + ident_normalizer: IdentNormalizer, } impl SqlExecutor { pub fn new(mut ctx: SessionContext) -> IcehutSQLResult { register_udfs(&mut ctx).context(ih_error::RegisterUDFSnafu)?; register_all(&mut ctx).context(ih_error::RegisterUDFSnafu)?; - Ok(Self { ctx }) + let enable_ident_normalization = ctx.enable_ident_normalization(); + Ok(Self { + ctx, + ident_normalizer: IdentNormalizer::new(enable_ident_normalization), + }) } #[tracing::instrument(level = "debug", skip(self), err, ret(level = tracing::Level::TRACE))] @@ -233,12 +239,13 @@ impl SqlExecutor { }, )?; let rest_catalog = iceberg_catalog.catalog(); + let new_table_name = self.ident_normalizer.normalize(new_table_name.clone()); let new_table_ident = Identifier::new( &new_table_db .iter() - .map(|v| v.value.clone()) + .map(|v| self.ident_normalizer.normalize(v.clone())) .collect::>(), - &new_table_name.value, + &new_table_name.clone(), ); if matches!( rest_catalog.tabular_exists(&new_table_ident).await, @@ -255,7 +262,7 @@ impl SqlExecutor { .create_table( new_table_ident.clone(), CreateTableCatalog { - name: new_table_name.value.clone(), + name: new_table_name.clone(), location, schema, partition_spec: None, @@ -448,7 +455,11 @@ impl SqlExecutor { }, )?; let rest_catalog = iceberg_catalog.catalog(); - let namespace_vec: Vec = name.0.iter().map(|ident| ident.value.clone()).collect(); + let namespace_vec: Vec = name + .0 + .iter() + .map(|ident| self.ident_normalizer.normalize(ident.clone())) + .collect(); let single_layer_namespace = vec![namespace_vec.join(".")]; let namespace = @@ -479,17 +490,18 @@ impl SqlExecutor { //println!("modified query: {:?}", statement.to_string()); if let DFStatement::Statement(s) = statement.clone() { - let mut ctx_provider = CustomContextProvider { + let mut ctx_provider = SessionContextProvider { state: &state, tables: HashMap::new(), }; + 
let references = state .resolve_table_references(&statement) .context(super::error::DataFusionSnafu)?; //println!("References: {:?}", references); for reference in references { let resolved = state.resolve_table_ref(reference); - if let Entry::Vacant(v) = ctx_provider.tables.entry(resolved.to_string()) { + if let Entry::Vacant(v) = ctx_provider.tables.entry(resolved.clone()) { if let Ok(schema) = state.schema_for_ref(resolved.clone()) { if let Some(table) = schema .table(&resolved.table) @@ -516,15 +528,20 @@ impl SqlExecutor { .ok_or(IcehutSQLError::TableProviderNotFound { table_name: table.clone(), })?; - ctx_provider.tables.insert( - format!("{catalog}.{schema}.{table}"), - provider_as_source(table_source), - ); + let resolved = state.resolve_table_ref(TableReference::full( + catalog.to_string(), + schema.to_string(), + table, + )); + ctx_provider + .tables + .insert(resolved, provider_as_source(table_source)); } } } - let planner = ExtendedSqlToRel::new(&ctx_provider); + let planner = + ExtendedSqlToRel::new(&ctx_provider, self.ctx.state().get_parser_options()); planner .sql_statement_to_plan(*s) .context(super::error::DataFusionSnafu) diff --git a/crates/runtime/src/datafusion/mod.rs b/crates/runtime/src/datafusion/mod.rs index dafb2b786..4fded0a3e 100644 --- a/crates/runtime/src/datafusion/mod.rs +++ b/crates/runtime/src/datafusion/mod.rs @@ -1,7 +1,7 @@ //pub mod analyzer; -pub mod context; pub mod error; pub mod functions; pub mod planner; //pub mod session; pub mod execution; +pub mod type_planner; diff --git a/crates/runtime/src/datafusion/planner.rs b/crates/runtime/src/datafusion/planner.rs index b986cce28..db46ca319 100644 --- a/crates/runtime/src/datafusion/planner.rs +++ b/crates/runtime/src/datafusion/planner.rs @@ -15,32 +15,23 @@ // specific language governing permissions and limitations // under the License. 
-use arrow::datatypes::{ - DataType, Field, Fields, IntervalUnit, Schema, TimeUnit, DECIMAL128_MAX_PRECISION, - DECIMAL256_MAX_PRECISION, DECIMAL_DEFAULT_SCALE, -}; -use datafusion::catalog_common::TableReference; -use datafusion::common::error::_plan_err; -use datafusion::common::{ - not_impl_err, plan_datafusion_err, plan_err, Constraint, Constraints, DFSchema, DFSchemaRef, - ToDFSchema, -}; -use datafusion::common::{DataFusionError, Result, SchemaError}; -use datafusion::logical_expr::sqlparser::ast; -use datafusion::logical_expr::sqlparser::ast::{ - ArrayElemTypeDef, ColumnDef, ExactNumberInfo, Ident, ObjectName, TableConstraint, -}; +use arrow::datatypes::{Field, Schema}; +use datafusion::common::Result; +use datafusion::common::{plan_err, ToDFSchema}; +use datafusion::logical_expr::sqlparser::ast::{Ident, ObjectName}; use datafusion::logical_expr::{CreateMemoryTable, DdlStatement, EmptyRelation, LogicalPlan}; -use datafusion::prelude::*; use datafusion::sql::parser::{DFParser, Statement as DFStatement}; use datafusion::sql::planner::{ - object_name_to_table_reference, ContextProvider, IdentNormalizer, PlannerContext, SqlToRel, + object_name_to_table_reference, ContextProvider, IdentNormalizer, ParserOptions, + PlannerContext, SqlToRel, }; use datafusion::sql::sqlparser::ast::{ ColumnDef as SQLColumnDef, ColumnOption, CreateTable as CreateTableStatement, - DataType as SQLDataType, Statement, TimezoneInfo, + DataType as SQLDataType, Statement, }; -use datafusion_common::SchemaReference; +use datafusion::sql::statement::{calc_inline_constraints_from_columns, object_name_to_string}; +use datafusion::sql::utils::normalize_ident; +use datafusion_common::{DFSchema, DFSchemaRef, SchemaReference, TableReference}; use datafusion_expr::DropCatalogSchema; use sqlparser::ast::ObjectType; use std::sync::Arc; @@ -51,6 +42,7 @@ where { inner: SqlToRel<'a, S>, // The wrapped type provider: &'a S, + options: ParserOptions, ident_normalizer: IdentNormalizer, } @@ -59,11 +51,14 @@ where S: ContextProvider, { /// Create a new instance of `ExtendedSqlToRel` - pub fn new(provider: &'a S) -> Self { + pub fn new(provider: &'a S, options: ParserOptions) -> Self { + let ident_normalize = options.enable_ident_normalization; + Self { inner: SqlToRel::new(provider), provider, - ident_normalizer: IdentNormalizer::default(), + options, + ident_normalizer: IdentNormalizer::new(ident_normalize), } } @@ -109,7 +104,10 @@ where } => match object_type { ObjectType::Database => { #[allow(clippy::unwrap_used)] - let name = object_name_to_table_reference(names.pop().unwrap(), true)?; + let name = object_name_to_table_reference( + names.pop().unwrap(), + self.options.enable_ident_normalization, + )?; let schema_name = match name { TableReference::Bare { table } => { Ok(SchemaReference::Bare { schema: table }) @@ -150,7 +148,9 @@ where let inline_constraints = calc_inline_constraints_from_columns(&columns); all_constraints.extend(inline_constraints); // Build column default values - let column_defaults = self.build_column_defaults(&columns, planner_context)?; + let column_defaults = self + .inner + .build_column_defaults(&columns, planner_context)?; // println!("column_defaults: {:?}", column_defaults); // println!("statement 11: {:?}", statement); let has_columns = !columns.is_empty(); @@ -167,7 +167,7 @@ where schema, }; let plan = LogicalPlan::EmptyRelation(plan); - let constraints = Self::new_constraint_from_table_constraints( + let constraints = SqlToRel::::new_constraint_from_table_constraints( &all_constraints, 
plan.schema(), )?; @@ -188,48 +188,11 @@ where } } - /// Returns a vector of (`column_name`, `default_expr`) pairs - pub fn build_column_defaults( - &self, - columns: &Vec, - planner_context: &mut PlannerContext, - ) -> Result> { - let mut column_defaults = vec![]; - // Default expressions are restricted, column references are not allowed - let empty_schema = DFSchema::empty(); - let error_desc = |e: DataFusionError| match e { - DataFusionError::SchemaError(SchemaError::FieldNotFound { .. }, _) => { - plan_datafusion_err!( - "Column reference is not allowed in the DEFAULT expression : {}", - e - ) - } - _ => e, - }; - - for column in columns { - if let Some(default_sql_expr) = column.options.iter().find_map(|o| match &o.option { - ColumnOption::Default(expr) => Some(expr), - _ => None, - }) { - let default_expr = self - .inner - .sql_to_expr(default_sql_expr.clone(), &empty_schema, planner_context) - .map_err(error_desc)?; - column_defaults.push(( - self.ident_normalizer.normalize(column.name.clone()), - default_expr, - )); - } - } - Ok(column_defaults) - } - pub fn build_schema(&self, columns: Vec) -> Result { let mut fields = Vec::with_capacity(columns.len()); for column in columns { - let data_type = self.convert_data_type(&column.data_type)?; + let data_type = self.inner.convert_data_type(&column.data_type)?; let not_nullable = column .options .iter() @@ -247,218 +210,6 @@ where Ok(Schema::new(fields)) } - pub fn convert_data_type(&self, sql_type: &SQLDataType) -> Result { - match sql_type { - SQLDataType::Array( - ArrayElemTypeDef::AngleBracket(inner_sql_type) - | ArrayElemTypeDef::SquareBracket(inner_sql_type, _), - ) => { - // Arrays may be multi-dimensional. - let inner_data_type = self.convert_data_type(inner_sql_type)?; - Ok(DataType::new_list(inner_data_type, true)) - } - SQLDataType::Array(ArrayElemTypeDef::None) => { - not_impl_err!("Arrays with unspecified type is not supported") - } - other => self.convert_simple_data_type(other), - } - } - - #[allow(clippy::too_many_lines, clippy::match_same_arms)] - fn convert_simple_data_type(&self, sql_type: &SQLDataType) -> Result { - match sql_type { - SQLDataType::Boolean | SQLDataType::Bool => Ok(DataType::Boolean), - SQLDataType::TinyInt(_) => Ok(DataType::Int8), - SQLDataType::SmallInt(_) | SQLDataType::Int2(_)| SQLDataType::Int16 => Ok(DataType::Int16), - SQLDataType::Int(_) - | SQLDataType::Integer(_) - | SQLDataType::Int4(_) - | SQLDataType::Int32 => Ok(DataType::Int32) , - SQLDataType::BigInt(_) | SQLDataType::Int8(_) | SQLDataType::Int64 => Ok(DataType::Int64), - SQLDataType::UnsignedTinyInt(_) | SQLDataType::UInt8 => Ok(DataType::UInt8), - SQLDataType::UnsignedSmallInt(_) - | SQLDataType::UnsignedInt2(_) - | SQLDataType::UInt16 => Ok(DataType::UInt16), - SQLDataType::UnsignedInt(_) - | SQLDataType::UnsignedInteger(_) - | SQLDataType::UnsignedInt4(_) - | SQLDataType::UInt32 => Ok(DataType::UInt32), - SQLDataType::Varchar(length) => match (length, true) { - (Some(_), false) => plan_err!( - "does not support Varchar with length, please set `support_varchar_with_length` to be true" - ), - _ => Ok(DataType::Utf8), - }, - SQLDataType::Blob(_) => Ok(DataType::Binary), - SQLDataType::UnsignedBigInt(_) - | SQLDataType::UnsignedInt8(_) - | SQLDataType::UInt64 => Ok(DataType::UInt64), - SQLDataType::Real - | SQLDataType::Float4 - | SQLDataType::Float(_) - | SQLDataType::Float32=> Ok(DataType::Float32), - SQLDataType::Double - | SQLDataType::DoublePrecision - | SQLDataType::Float8 - | SQLDataType::Float64 => Ok(DataType::Float64), - 
SQLDataType::Char(_) | SQLDataType::Text | SQLDataType::String(_) => Ok(DataType::Utf8), - SQLDataType::Timestamp(precision, tz_info) => { - let tz = if matches!(tz_info, TimezoneInfo::Tz) - || matches!(tz_info, TimezoneInfo::WithTimeZone) - { - // Timestamp With Time Zone - // INPUT : [SQLDataType] TimestampTz + [RuntimeConfig] Time Zone - // OUTPUT: [ArrowDataType] Timestamp - self.provider.options().execution.time_zone.clone() - } else { - // Timestamp Without Time zone - None - }; - let precision = match precision { - Some(0) => TimeUnit::Second, - Some(3) => TimeUnit::Millisecond, - Some(6) => TimeUnit::Microsecond, - None | Some(9) => TimeUnit::Nanosecond, - _ => unreachable!(), - }; - Ok(DataType::Timestamp(precision, tz.map(Into::into))) - } - SQLDataType::Date => Ok(DataType::Date32), - SQLDataType::Time(None, tz_info) => { - if matches!(tz_info, TimezoneInfo::None) - || matches!(tz_info, TimezoneInfo::WithoutTimeZone) - { - Ok(DataType::Time64(TimeUnit::Nanosecond)) - } else { - // We dont support TIMETZ and TIME WITH TIME ZONE for now - not_impl_err!("Unsupported SQL type {sql_type:?}") - } - } - SQLDataType::Numeric(exact_number_info) | SQLDataType::Decimal(exact_number_info) => { - let (precision, scale) = match *exact_number_info { - ExactNumberInfo::None => (None, None), - ExactNumberInfo::Precision(precision) => (Some(precision), None), - ExactNumberInfo::PrecisionAndScale(precision, scale) => { - (Some(precision), Some(scale)) - } - }; - make_decimal_type(precision, scale) - } - SQLDataType::Bytea => Ok(DataType::Binary), - SQLDataType::Interval => Ok(DataType::Interval(IntervalUnit::MonthDayNano)), - SQLDataType::Struct(fields, _) => { - let fields = fields - .iter() - .enumerate() - .map(|(idx, field)| { - let data_type = self.convert_data_type(&field.field_type)?; - let field_name = field.field_name.as_ref().map_or_else( - || Ident::new(format!("c{idx}")), - Clone::clone, - ); - Ok(Arc::new(Field::new( - self.ident_normalizer.normalize(field_name), - data_type, - true, - ))) - }) - .collect::>>()?; - Ok(DataType::Struct(Fields::from(fields))) - } - // https://github.com/apache/datafusion/issues/12644 - SQLDataType::JSON => Ok(DataType::Utf8), - SQLDataType::Custom(a, b) => match a.to_string().to_uppercase().as_str() { - "VARIANT" => Ok(DataType::Utf8), - "TIMESTAMP_NTZ" => { - let parsed_b: Option = b.iter().next().and_then(|s| s.parse().ok()); - match parsed_b { - Some(0) => Ok(DataType::Timestamp(TimeUnit::Second, None)), - Some(3) => Ok(DataType::Timestamp(TimeUnit::Millisecond, None)), - Some(6) => Ok(DataType::Timestamp(TimeUnit::Microsecond, None)), - Some(9) => Ok(DataType::Timestamp(TimeUnit::Nanosecond, None)), - _ => not_impl_err!("Unsupported SQL TIMESTAMP_NZT precision {parsed_b:?}"), - } - } - "NUMBER" => { - let (precision, scale) = match b.len() { - 0 => (None, None), - 1 => { - let precision = b[0].parse().map_err(|_| { - DataFusionError::Plan(format!("Invalid precision: {}", b[0])) - })?; - (Some(precision), None) - } - 2 => { - let precision = b[0].parse().map_err(|_| { - DataFusionError::Plan(format!("Invalid precision: {}", b[0])) - })?; - let scale = b[1].parse().map_err(|_| { - DataFusionError::Plan(format!("Invalid scale: {}", b[1])) - })?; - (Some(precision), Some(scale)) - } - _ => { - return Err(DataFusionError::Plan(format!( - "Invalid NUMBER type format: {b:?}" - ))); - } - }; - make_decimal_type(precision, scale) - } - _ => Ok(DataType::Utf8), - }, - // Explicitly list all other types so that if sqlparser - // adds/changes the `SQLDataType` 
the compiler will tell us on upgrade - // and avoid bugs like https://github.com/apache/datafusion/issues/3059 - SQLDataType::Nvarchar(_) - | SQLDataType::Uuid - | SQLDataType::Binary(_) - | SQLDataType::Varbinary(_) - | SQLDataType::Datetime(_) - | SQLDataType::Regclass - | SQLDataType::Array(_) - | SQLDataType::Enum(_, _) - | SQLDataType::Set(_) - | SQLDataType::MediumInt(_) - | SQLDataType::UnsignedMediumInt(_) - | SQLDataType::Character(_) - | SQLDataType::CharacterVarying(_) - | SQLDataType::CharVarying(_) - | SQLDataType::CharacterLargeObject(_) - | SQLDataType::CharLargeObject(_) - | SQLDataType::Time(Some(_), _) - | SQLDataType::Dec(_) - | SQLDataType::BigNumeric(_) - | SQLDataType::BigDecimal(_) - | SQLDataType::Clob(_) - | SQLDataType::Bytes(_) - | SQLDataType::Int128 - | SQLDataType::Int256 - | SQLDataType::UInt128 - | SQLDataType::UInt256 - | SQLDataType::Date32 - | SQLDataType::Datetime64(_, _) - | SQLDataType::FixedString(_) - | SQLDataType::Map(_, _) - | SQLDataType::Tuple(_) - | SQLDataType::Nested(_) - | SQLDataType::Union(_) - | SQLDataType::Nullable(_) - | SQLDataType::LowCardinality(_) - | SQLDataType::Trigger - | SQLDataType::JSONB - | SQLDataType::TinyBlob - | SQLDataType::MediumBlob - | SQLDataType::LongBlob - | SQLDataType::TinyText - | SQLDataType::MediumText - | SQLDataType::LongText - | SQLDataType::Bit(_) - | SQLDataType::BitVarying(_) - | SQLDataType::Unspecified => not_impl_err!("Unsupported SQL type {sql_type:?}"), - } - } - pub fn add_custom_metadata(&self, field: &mut Field, sql_type: &SQLDataType) { match sql_type { SQLDataType::JSON => { @@ -481,73 +232,9 @@ where } } - fn new_constraint_from_table_constraints( - constraints: &[TableConstraint], - df_schema: &DFSchemaRef, - ) -> Result { - let constraints = constraints - .iter() - .map(|c: &TableConstraint| match c { - TableConstraint::Unique { name, columns, .. } => { - let field_names = df_schema.field_names(); - // Get unique constraint indices in the schema: - let indices = columns - .iter() - .map(|u| { - let idx = field_names - .iter() - .position(|item| *item.to_lowercase() == u.value.to_lowercase()) - .ok_or_else(|| { - let name = name.as_ref().map_or(String::new(), |name| { - format!("with name '{name}' ") - }); - DataFusionError::Execution(format!( - "Column for unique constraint {}not found in schema: {}", - name, u.value - )) - })?; - Ok(idx) - }) - .collect::>>()?; - Ok(Constraint::Unique(indices)) - } - TableConstraint::PrimaryKey { columns, .. } => { - let field_names = df_schema.field_names(); - // Get primary key indices in the schema: - let indices = columns - .iter() - .map(|pk| { - let idx = field_names - .iter() - .position(|item| *item.to_lowercase() == pk.value.to_lowercase()) - .ok_or_else(|| { - DataFusionError::Execution(format!( - "Column for primary key not found in schema: {}", - pk.value - )) - })?; - Ok(idx) - }) - .collect::>>()?; - Ok(Constraint::PrimaryKey(indices)) - } - TableConstraint::ForeignKey { .. } => { - _plan_err!("Foreign key constraints are not currently supported") - } - TableConstraint::Check { .. } => { - _plan_err!("Check constraints are not currently supported") - } - TableConstraint::FulltextOrSpatial { .. } | TableConstraint::Index { .. 
} => { - _plan_err!("Indexes are not currently supported") - } - }) - .collect::>>()?; - Ok(Constraints::new_unverified(constraints)) - } - fn show_variable_to_plan(&self, variable: &[Ident]) -> Result { //println!("SHOW variable: {:?}", variable); - if !self.has_table("information_schema", "df_settings") { + if !self.inner.has_table("information_schema", "df_settings") { return plan_err!( "SHOW [VARIABLE] is not supported unless information_schema is enabled" ); @@ -555,7 +242,7 @@ where let verbose = variable .last() - .is_some_and(|s| ident_to_string(s) == "verbose"); + .is_some_and(|s| normalize_ident(s.to_owned()) == "verbose"); let mut variable_vec = variable.to_vec(); let mut columns: String = "name, value".to_owned(); @@ -619,126 +306,4 @@ where }, ) } - - fn has_table(&self, schema: &str, table: &str) -> bool { - let tables_reference = TableReference::Partial { - schema: schema.into(), - table: table.into(), - }; - self.provider.get_table_source(tables_reference).is_ok() - } -} - -fn calc_inline_constraints_from_columns(columns: &[ColumnDef]) -> Vec { - let mut constraints = vec![]; - for column in columns { - for ast::ColumnOptionDef { name, option } in &column.options { - match option { - ColumnOption::Unique { - is_primary: false, - characteristics, - } => constraints.push(TableConstraint::Unique { - name: name.clone(), - columns: vec![column.name.clone()], - characteristics: *characteristics, - index_name: None, - index_type_display: ast::KeyOrIndexDisplay::None, - index_type: None, - index_options: vec![], - nulls_distinct: ast::NullsDistinctOption::None, - }), - ColumnOption::Unique { - is_primary: true, - characteristics, - } => constraints.push(TableConstraint::PrimaryKey { - name: name.clone(), - columns: vec![column.name.clone()], - characteristics: *characteristics, - index_name: None, - index_type: None, - index_options: vec![], - }), - ColumnOption::ForeignKey { - foreign_table, - referred_columns, - on_delete, - on_update, - characteristics, - } => constraints.push(TableConstraint::ForeignKey { - name: name.clone(), - columns: vec![], - foreign_table: foreign_table.clone(), - referred_columns: referred_columns.clone(), - on_delete: *on_delete, - on_update: *on_update, - characteristics: *characteristics, - }), - ColumnOption::Check(expr) => constraints.push(TableConstraint::Check { - name: name.clone(), - expr: Box::new(expr.clone()), - }), - // Other options are not constraint related. - ColumnOption::Default(_) - | ColumnOption::Null - | ColumnOption::NotNull - | ColumnOption::DialectSpecific(_) - | ColumnOption::CharacterSet(_) - | ColumnOption::Generated { .. 
} - | ColumnOption::Comment(_) - | ColumnOption::Options(_) - | ColumnOption::Materialized(_) - | ColumnOption::Ephemeral(_) - | ColumnOption::Alias(_) - | ColumnOption::OnUpdate(_) - | ColumnOption::Identity(_) - | ColumnOption::OnConflict(_) - | ColumnOption::Policy(_) - | ColumnOption::Tags(_) => {} - } - } - } - constraints -} - -#[allow(clippy::cast_possible_truncation, clippy::as_conversions)] -pub fn make_decimal_type(precision: Option, scale: Option) -> Result { - // postgres like behavior - let (precision, scale) = match (precision, scale) { - (Some(p), Some(s)) => (p as u8, s as i8), - (Some(p), None) => (p as u8, 0), - (None, Some(_)) => return plan_err!("Cannot specify only scale for decimal data type"), - (None, None) => (DECIMAL128_MAX_PRECISION, DECIMAL_DEFAULT_SCALE), - }; - - if precision == 0 || precision > DECIMAL256_MAX_PRECISION || scale.unsigned_abs() > precision { - plan_err!( - "Decimal(precision = {precision}, scale = {scale}) should satisfy `0 < precision <= 76`, and `scale <= precision`." - ) - } else if precision > DECIMAL128_MAX_PRECISION && precision <= DECIMAL256_MAX_PRECISION { - Ok(DataType::Decimal256(precision, scale)) - } else { - Ok(DataType::Decimal128(precision, scale)) - } -} - -fn ident_to_string(ident: &Ident) -> String { - normalize_ident(ident.to_owned()) -} - -fn object_name_to_string(object_name: &ObjectName) -> String { - object_name - .0 - .iter() - .map(ident_to_string) - .collect::>() - .join(".") -} - -// Normalize an owned identifier to a lowercase string unless the identifier is quoted. -#[must_use] -pub fn normalize_ident(id: Ident) -> String { - match id.quote_style { - Some(_) => id.value, - None => id.value.to_ascii_lowercase(), - } } diff --git a/crates/runtime/src/datafusion/type_planner.rs b/crates/runtime/src/datafusion/type_planner.rs new file mode 100644 index 000000000..068047784 --- /dev/null +++ b/crates/runtime/src/datafusion/type_planner.rs @@ -0,0 +1,67 @@ +use arrow::datatypes::{DataType, TimeUnit}; +use datafusion::common::Result; +use datafusion::logical_expr::planner::TypePlanner; +use datafusion::logical_expr::sqlparser::ast; +use datafusion::sql::sqlparser::ast::DataType as SQLDataType; +use datafusion::sql::utils::make_decimal_type; +use datafusion_common::{not_impl_err, DataFusionError}; + +#[derive(Debug)] +pub struct CustomTypePlanner {} + +impl TypePlanner for CustomTypePlanner { + fn plan_type(&self, sql_type: &ast::DataType) -> Result> { + match sql_type { + SQLDataType::Int32 => Ok(Some(DataType::Int32)), + SQLDataType::Int64 => Ok(Some(DataType::Int64)), + SQLDataType::UInt32 => Ok(Some(DataType::UInt32)), + SQLDataType::Blob(_) => Ok(Some(DataType::Binary)), + SQLDataType::Float(_) | SQLDataType::Float32 => Ok(Some(DataType::Float32)), + SQLDataType::Float64 => Ok(Some(DataType::Float64)), + + // https://github.com/apache/datafusion/issues/12644 + SQLDataType::JSON => Ok(Some(DataType::Utf8)), + SQLDataType::Custom(a, b) => match a.to_string().to_uppercase().as_str() { + "VARIANT" => Ok(Some(DataType::Utf8)), + "TIMESTAMP_NTZ" => { + let parsed_b: Option = b.iter().next().and_then(|s| s.parse().ok()); + match parsed_b { + Some(0) => Ok(Some(DataType::Timestamp(TimeUnit::Second, None))), + Some(3) => Ok(Some(DataType::Timestamp(TimeUnit::Millisecond, None))), + Some(6) => Ok(Some(DataType::Timestamp(TimeUnit::Microsecond, None))), + Some(9) => Ok(Some(DataType::Timestamp(TimeUnit::Nanosecond, None))), + _ => not_impl_err!("Unsupported SQL TIMESTAMP_NZT precision {parsed_b:?}"), + } + } + "NUMBER" => { + let 
(precision, scale) = match b.len() { + 0 => (None, None), + 1 => { + let precision = b[0].parse().map_err(|_| { + DataFusionError::Plan(format!("Invalid precision: {}", b[0])) + })?; + (Some(precision), None) + } + 2 => { + let precision = b[0].parse().map_err(|_| { + DataFusionError::Plan(format!("Invalid precision: {}", b[0])) + })?; + let scale = b[1].parse().map_err(|_| { + DataFusionError::Plan(format!("Invalid scale: {}", b[1])) + })?; + (Some(precision), Some(scale)) + } + _ => { + return Err(DataFusionError::Plan(format!( + "Invalid NUMBER type format: {b:?}" + ))); + } + }; + make_decimal_type(precision, scale).map(Some) + } + _ => Ok(None), + }, + _ => Ok(None), + } + } +}
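
For reference, the sketch below shows how the new `CustomTypePlanner` is wired into a DataFusion session, mirroring the `SessionStateBuilder` change in `crates/control_plane/src/service.rs` above. It is a minimal, hypothetical example rather than the project's actual setup: the real code also registers `IcebergQueryPlanner` and an Iceberg catalog, the import paths are assumed, and the sample query assumes the session's SQL dialect parses `TIMESTAMP_NTZ(6)` and `NUMBER(10, 2)` as `DataType::Custom`, which is what `plan_type` matches on.

```rust
use std::sync::Arc;

use datafusion::error::Result;
use datafusion::execution::context::SessionContext;
use datafusion::execution::session_state::SessionStateBuilder;
use runtime::datafusion::type_planner::CustomTypePlanner;

#[tokio::main]
async fn main() -> Result<()> {
    // Register the custom type planner on the session state; the PR additionally
    // chains `.with_query_planner(Arc::new(IcebergQueryPlanner {}))`, omitted here
    // to keep the sketch self-contained.
    let state = SessionStateBuilder::new()
        .with_default_features()
        .with_type_planner(Arc::new(CustomTypePlanner {}))
        .build();
    let ctx = SessionContext::new_with_state(state);

    // Hypothetical query: the CAST targets resolve through `TypePlanner::plan_type`,
    // so TIMESTAMP_NTZ(6) maps to Timestamp(Microsecond, None) and
    // NUMBER(10, 2) maps to Decimal128(10, 2).
    let df = ctx
        .sql(
            "SELECT CAST('2024-01-01 00:00:00' AS TIMESTAMP_NTZ(6)) AS ts, \
             CAST(1 AS NUMBER(10, 2)) AS n",
        )
        .await?;
    df.show().await?;
    Ok(())
}
```

Compared with the deleted `convert_simple_data_type` in `planner.rs`, this keeps the Snowflake-specific mappings (`VARIANT`, `TIMESTAMP_NTZ`, `NUMBER`) in one small `TypePlanner` implementation and lets stock DataFusion handle every other SQL type.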