diff --git a/Cargo.toml b/Cargo.toml index 32034f1ab9..1deb4e806c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -22,6 +22,7 @@ members = [ "crates/examples", "crates/iceberg", "crates/integrations/*", + "crates/sqllogictests", "crates/test_utils", ] exclude = ["bindings/python"] @@ -39,6 +40,7 @@ rust-version = "1.77.1" anyhow = "1.0.72" apache-avro = "0.17" array-init = "2" +arrow = { version = "52" } arrow-arith = { version = "52" } arrow-array = { version = "52" } arrow-ord = { version = "52" } @@ -56,6 +58,8 @@ bytes = "1.5" chrono = "0.4.34" ctor = "0.2.8" derive_builder = "0.20" +datafusion = { version = "41.0.0" } +datafusion-common = { version = "41.0.0" } either = "1" env_logger = "0.11.0" fnv = "1" @@ -84,6 +88,7 @@ serde_derive = "1" serde_json = "1" serde_repr = "0.1.16" serde_with = "3.4" +sqlparser = { version = "0.50.0", features = ["visitor"] } tempfile = "3.8" tokio = { version = "1", default-features = false } typed-builder = "0.19" diff --git a/crates/iceberg/src/writer/file_writer/track_writer.rs b/crates/iceberg/src/writer/file_writer/track_writer.rs index 6c60a1aa70..7b916aeb58 100644 --- a/crates/iceberg/src/writer/file_writer/track_writer.rs +++ b/crates/iceberg/src/writer/file_writer/track_writer.rs @@ -42,10 +42,9 @@ impl TrackWriter { impl FileWrite for TrackWriter { async fn write(&mut self, bs: Bytes) -> Result<()> { let size = bs.len(); - self.inner.write(bs).await.map(|v| { + self.inner.write(bs).await.inspect(|_| { self.written_size .fetch_add(size as i64, std::sync::atomic::Ordering::Relaxed); - v }) } diff --git a/crates/integrations/datafusion/Cargo.toml b/crates/integrations/datafusion/Cargo.toml index 87e809cec0..6a8cf00f92 100644 --- a/crates/integrations/datafusion/Cargo.toml +++ b/crates/integrations/datafusion/Cargo.toml @@ -31,7 +31,7 @@ keywords = ["iceberg", "integrations", "datafusion"] [dependencies] anyhow = { workspace = true } async-trait = { workspace = true } -datafusion = { version = "41.0.0" } +datafusion = { workspace = true } futures = { workspace = true } iceberg = { workspace = true } tokio = { workspace = true } diff --git a/crates/sqllogictests/Cargo.toml b/crates/sqllogictests/Cargo.toml new file mode 100644 index 0000000000..77e46354fd --- /dev/null +++ b/crates/sqllogictests/Cargo.toml @@ -0,0 +1,38 @@ +[package] +name = "sqllogictests" +version = { workspace = true } +edition = { workspace = true } +homepage = { workspace = true } +repository = { workspace = true } +license = { workspace = true } +rust-version = { workspace = true } + +[dependencies] +anyhow = { workspace = true } +arrow = { workspace = true } +# For spark-connect-rs +arrow_51 = { version = "51", package = "arrow" } +async-trait = { workspace = true } +bigdecimal = "0.4.1" +datafusion = { workspace = true, default-features = true } +datafusion-common = { workspace = true, default-features = true } +env_logger = { workspace = true } +half = "2.4.1" +iceberg-catalog-rest = { path = "../catalog/rest" } +iceberg-datafusion = { path = "../integrations/datafusion" } +itertools = "0.13.0" +log = "0.4.22" +rust_decimal = { version = "1.27.0" } +spark-connect-rs = "0.0.1-beta.5" +sqllogictest = "0.22" +tokio = "1.38.0" +toml = "0.8.19" + +[dev-dependencies] +iceberg_test_utils = { path = "../test_utils", features = ["tests"] } +libtest-mimic = "0.7.3" + +[[test]] +harness = false +name = "sqllogictests" +path = "tests/sqllogictests.rs" diff --git a/crates/sqllogictests/src/display/conversion.rs b/crates/sqllogictests/src/display/conversion.rs new file mode 100644 index 
0000000000..d68afaaf00 --- /dev/null +++ b/crates/sqllogictests/src/display/conversion.rs @@ -0,0 +1,82 @@ +use arrow::array::types::{Decimal128Type, Decimal256Type, DecimalType}; +use arrow::datatypes::i256; +use bigdecimal::BigDecimal; +use half::f16; +use rust_decimal::prelude::*; + +/// Represents a constant for NULL string in your database. +pub const NULL_STR: &str = "NULL"; + +pub(crate) fn bool_to_str(value: bool) -> String { + if value { + "true".to_string() + } else { + "false".to_string() + } +} + +pub(crate) fn varchar_to_str(value: &str) -> String { + if value.is_empty() { + "(empty)".to_string() + } else { + value.trim_end_matches('\n').to_string() + } +} + +pub(crate) fn f16_to_str(value: f16) -> String { + if value.is_nan() { + // The sign of NaN can be different depending on platform. + // So the string representation of NaN ignores the sign. + "NaN".to_string() + } else if value == f16::INFINITY { + "Infinity".to_string() + } else if value == f16::NEG_INFINITY { + "-Infinity".to_string() + } else { + big_decimal_to_str(BigDecimal::from_str(&value.to_string()).unwrap()) + } +} + +pub(crate) fn f32_to_str(value: f32) -> String { + if value.is_nan() { + // The sign of NaN can be different depending on platform. + // So the string representation of NaN ignores the sign. + "NaN".to_string() + } else if value == f32::INFINITY { + "Infinity".to_string() + } else if value == f32::NEG_INFINITY { + "-Infinity".to_string() + } else { + big_decimal_to_str(BigDecimal::from_str(&value.to_string()).unwrap()) + } +} + +pub(crate) fn f64_to_str(value: f64) -> String { + if value.is_nan() { + // The sign of NaN can be different depending on platform. + // So the string representation of NaN ignores the sign. + "NaN".to_string() + } else if value == f64::INFINITY { + "Infinity".to_string() + } else if value == f64::NEG_INFINITY { + "-Infinity".to_string() + } else { + big_decimal_to_str(BigDecimal::from_str(&value.to_string()).unwrap()) + } +} + +pub(crate) fn i128_to_str(value: i128, precision: &u8, scale: &i8) -> String { + big_decimal_to_str( + BigDecimal::from_str(&Decimal128Type::format_decimal(value, *precision, *scale)).unwrap(), + ) +} + +pub(crate) fn i256_to_str(value: i256, precision: &u8, scale: &i8) -> String { + big_decimal_to_str( + BigDecimal::from_str(&Decimal256Type::format_decimal(value, *precision, *scale)).unwrap(), + ) +} + +pub(crate) fn big_decimal_to_str(value: BigDecimal) -> String { + value.round(12).normalized().to_string() +} diff --git a/crates/sqllogictests/src/display/conversion_51.rs b/crates/sqllogictests/src/display/conversion_51.rs new file mode 100644 index 0000000000..b1243e2af7 --- /dev/null +++ b/crates/sqllogictests/src/display/conversion_51.rs @@ -0,0 +1,82 @@ +use arrow_51::array::types::{Decimal128Type, Decimal256Type, DecimalType}; +use arrow_51::datatypes::i256; +use bigdecimal::BigDecimal; +use half::f16; +use rust_decimal::prelude::*; + +/// Represents a constant for NULL string in your database. +pub const NULL_STR: &str = "NULL"; + +pub(crate) fn bool_to_str(value: bool) -> String { + if value { + "true".to_string() + } else { + "false".to_string() + } +} + +pub(crate) fn varchar_to_str(value: &str) -> String { + if value.is_empty() { + "(empty)".to_string() + } else { + value.trim_end_matches('\n').to_string() + } +} + +pub(crate) fn f16_to_str(value: f16) -> String { + if value.is_nan() { + // The sign of NaN can be different depending on platform. + // So the string representation of NaN ignores the sign. 
+ "NaN".to_string() + } else if value == f16::INFINITY { + "Infinity".to_string() + } else if value == f16::NEG_INFINITY { + "-Infinity".to_string() + } else { + big_decimal_to_str(BigDecimal::from_str(&value.to_string()).unwrap()) + } +} + +pub(crate) fn f32_to_str(value: f32) -> String { + if value.is_nan() { + // The sign of NaN can be different depending on platform. + // So the string representation of NaN ignores the sign. + "NaN".to_string() + } else if value == f32::INFINITY { + "Infinity".to_string() + } else if value == f32::NEG_INFINITY { + "-Infinity".to_string() + } else { + big_decimal_to_str(BigDecimal::from_str(&value.to_string()).unwrap()) + } +} + +pub(crate) fn f64_to_str(value: f64) -> String { + if value.is_nan() { + // The sign of NaN can be different depending on platform. + // So the string representation of NaN ignores the sign. + "NaN".to_string() + } else if value == f64::INFINITY { + "Infinity".to_string() + } else if value == f64::NEG_INFINITY { + "-Infinity".to_string() + } else { + big_decimal_to_str(BigDecimal::from_str(&value.to_string()).unwrap()) + } +} + +pub(crate) fn i128_to_str(value: i128, precision: &u8, scale: &i8) -> String { + big_decimal_to_str( + BigDecimal::from_str(&Decimal128Type::format_decimal(value, *precision, *scale)).unwrap(), + ) +} + +pub(crate) fn i256_to_str(value: i256, precision: &u8, scale: &i8) -> String { + big_decimal_to_str( + BigDecimal::from_str(&Decimal256Type::format_decimal(value, *precision, *scale)).unwrap(), + ) +} + +pub(crate) fn big_decimal_to_str(value: BigDecimal) -> String { + value.round(12).normalized().to_string() +} diff --git a/crates/sqllogictests/src/display/mod.rs b/crates/sqllogictests/src/display/mod.rs new file mode 100644 index 0000000000..55fe9d2ad9 --- /dev/null +++ b/crates/sqllogictests/src/display/mod.rs @@ -0,0 +1,4 @@ +pub mod conversion; +pub mod conversion_51; +pub mod normalize; +pub mod normalize_51; diff --git a/crates/sqllogictests/src/display/normalize.rs b/crates/sqllogictests/src/display/normalize.rs new file mode 100644 index 0000000000..5e975ef375 --- /dev/null +++ b/crates/sqllogictests/src/display/normalize.rs @@ -0,0 +1,203 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use anyhow::anyhow; +use arrow::array::{ + ArrayRef, BooleanArray, Decimal128Array, Decimal256Array, Float16Array, Float32Array, + Float64Array, LargeStringArray, RecordBatch, StringArray, StringViewArray, +}; +use arrow::datatypes::{DataType, Fields}; +use arrow::util::display::ArrayFormatter; +use datafusion_common::format::DEFAULT_FORMAT_OPTIONS; + +use crate::display::conversion::*; +use crate::engine::output::DFColumnType; + +/// Converts `batches` to a result as expected by sqllogicteset. 
+pub(crate) fn convert_batches(batches: Vec) -> anyhow::Result>> { + if batches.is_empty() { + Ok(vec![]) + } else { + let schema = batches[0].schema(); + let mut rows = vec![]; + for batch in batches { + // Verify schema + if !schema.contains(&batch.schema()) { + return Err(anyhow!( + "Schema mismatch. Previously had\n{:#?}\n\nGot:\n{:#?}", + &schema, + batch.schema() + )); + } + + let new_rows = convert_batch(batch)?.into_iter().flat_map(expand_row); + rows.extend(new_rows); + } + Ok(rows) + } +} + +/// special case rows that have newlines in them (like explain plans) +// +/// Transform inputs like: +/// ```text +/// [ +/// "logical_plan", +/// "Sort: d.b ASC NULLS LAST\n Projection: d.b, MAX(d.a) AS max_a", +/// ] +/// ``` +/// +/// Into one cell per line, adding lines if necessary +/// ```text +/// [ +/// "logical_plan", +/// ] +/// [ +/// "Sort: d.b ASC NULLS LAST", +/// ] +/// [ <--- newly added row +/// "|-- Projection: d.b, MAX(d.a) AS max_a", +/// ] +/// ``` +fn expand_row(mut row: Vec) -> impl Iterator> { + use std::iter::once; + + use itertools::Either; + + // check last cell + if let Some(cell) = row.pop() { + let lines: Vec<_> = cell.split('\n').collect(); + + // no newlines in last cell + if lines.len() < 2 { + row.push(cell); + return Either::Left(once(row)); + } + + // form new rows with each additional line + let new_lines: Vec<_> = lines + .into_iter() + .enumerate() + .map(|(idx, l)| { + // replace any leading spaces with '-' as + // `sqllogictests` ignores whitespace differences + // + // See https://github.com/apache/datafusion/issues/6328 + let content = l.trim_start(); + let new_prefix = "-".repeat(l.len() - content.len()); + // maintain for each line a number, so + // reviewing explain result changes is easier + let line_num = idx + 1; + vec![format!("{line_num:02}){new_prefix}{content}")] + }) + .collect(); + + Either::Right(once(row).chain(new_lines)) + } else { + Either::Left(once(row)) + } +} + +/// Convert a single batch to a `Vec>` for comparison +fn convert_batch(batch: RecordBatch) -> anyhow::Result>> { + (0..batch.num_rows()) + .map(|row| { + batch + .columns() + .iter() + .map(|col| cell_to_string(col, row)) + .collect::>>() + }) + .collect() +} + +macro_rules! get_row_value { + ($array_type:ty, $column: ident, $row: ident) => {{ + let array = $column.as_any().downcast_ref::<$array_type>().unwrap(); + + array.value($row) + }}; +} + +/// Normalizes the content of a single cell in RecordBatch prior to printing. +/// +/// This is to make the output comparable to the semi-standard .slt format +/// +/// Normalizations applied to [NULL Values and empty strings] +/// +/// [NULL Values and empty strings]: https://duckdb.org/dev/sqllogictest/result_verification#null-values-and-empty-strings +/// +/// Floating numbers are rounded to have a consistent representation with the Postgres runner. 
+pub fn cell_to_string(col: &ArrayRef, row: usize) -> anyhow::Result { + if !col.is_valid(row) { + // represent any null value with the string "NULL" + Ok(NULL_STR.to_string()) + } else { + match col.data_type() { + DataType::Null => Ok(NULL_STR.to_string()), + DataType::Boolean => Ok(bool_to_str(get_row_value!(BooleanArray, col, row))), + DataType::Float16 => Ok(f16_to_str(get_row_value!(Float16Array, col, row))), + DataType::Float32 => Ok(f32_to_str(get_row_value!(Float32Array, col, row))), + DataType::Float64 => Ok(f64_to_str(get_row_value!(Float64Array, col, row))), + DataType::Decimal128(precision, scale) => { + let value = get_row_value!(Decimal128Array, col, row); + Ok(i128_to_str(value, precision, scale)) + } + DataType::Decimal256(precision, scale) => { + let value = get_row_value!(Decimal256Array, col, row); + Ok(i256_to_str(value, precision, scale)) + } + DataType::LargeUtf8 => Ok(varchar_to_str(get_row_value!(LargeStringArray, col, row))), + DataType::Utf8 => Ok(varchar_to_str(get_row_value!(StringArray, col, row))), + DataType::Utf8View => Ok(varchar_to_str(get_row_value!(StringViewArray, col, row))), + _ => { + let f = ArrayFormatter::try_new(col.as_ref(), &DEFAULT_FORMAT_OPTIONS); + Ok(f.unwrap().value(row).to_string()) + } + } + } +} + +/// Converts columns to a result as expected by sqllogicteset. +pub(crate) fn convert_schema_to_types(columns: &Fields) -> Vec { + columns + .iter() + .map(|f| f.data_type()) + .map(|data_type| match data_type { + DataType::Boolean => DFColumnType::Boolean, + DataType::Int8 + | DataType::Int16 + | DataType::Int32 + | DataType::Int64 + | DataType::UInt8 + | DataType::UInt16 + | DataType::UInt32 + | DataType::UInt64 => DFColumnType::Integer, + DataType::Float16 + | DataType::Float32 + | DataType::Float64 + | DataType::Decimal128(_, _) + | DataType::Decimal256(_, _) => DFColumnType::Float, + DataType::Utf8 | DataType::LargeUtf8 | DataType::Utf8View => DFColumnType::Text, + DataType::Date32 | DataType::Date64 | DataType::Time32(_) | DataType::Time64(_) => { + DFColumnType::DateTime + } + DataType::Timestamp(_, _) => DFColumnType::Timestamp, + _ => DFColumnType::Another, + }) + .collect() +} diff --git a/crates/sqllogictests/src/display/normalize_51.rs b/crates/sqllogictests/src/display/normalize_51.rs new file mode 100644 index 0000000000..f48744f22f --- /dev/null +++ b/crates/sqllogictests/src/display/normalize_51.rs @@ -0,0 +1,205 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use anyhow::anyhow; +use arrow_51::array::{ + ArrayRef, BooleanArray, Decimal128Array, Decimal256Array, Float16Array, Float32Array, + Float64Array, LargeStringArray, RecordBatch, StringArray, StringViewArray, +}; +use arrow_51::datatypes::{DataType, Fields}; +use arrow_51::util::display::{ArrayFormatter, DurationFormat, FormatOptions}; + +use crate::display::conversion_51::*; +use crate::engine::output::DFColumnType; + +const DEFAULT_FORMAT_OPTIONS: FormatOptions<'static> = + FormatOptions::new().with_duration_format(DurationFormat::Pretty); + +/// Converts `batches` to a result as expected by sqllogicteset. +pub(crate) fn convert_batches(batches: Vec) -> anyhow::Result>> { + if batches.is_empty() { + Ok(vec![]) + } else { + let schema = batches[0].schema(); + let mut rows = vec![]; + for batch in batches { + // Verify schema + if !schema.contains(&batch.schema()) { + return Err(anyhow!( + "Schema mismatch. Previously had\n{:#?}\n\nGot:\n{:#?}", + &schema, + batch.schema() + )); + } + + let new_rows = convert_batch(batch)?.into_iter().flat_map(expand_row); + rows.extend(new_rows); + } + Ok(rows) + } +} + +/// special case rows that have newlines in them (like explain plans) +// +/// Transform inputs like: +/// ```text +/// [ +/// "logical_plan", +/// "Sort: d.b ASC NULLS LAST\n Projection: d.b, MAX(d.a) AS max_a", +/// ] +/// ``` +/// +/// Into one cell per line, adding lines if necessary +/// ```text +/// [ +/// "logical_plan", +/// ] +/// [ +/// "Sort: d.b ASC NULLS LAST", +/// ] +/// [ <--- newly added row +/// "|-- Projection: d.b, MAX(d.a) AS max_a", +/// ] +/// ``` +fn expand_row(mut row: Vec) -> impl Iterator> { + use std::iter::once; + + use itertools::Either; + + // check last cell + if let Some(cell) = row.pop() { + let lines: Vec<_> = cell.split('\n').collect(); + + // no newlines in last cell + if lines.len() < 2 { + row.push(cell); + return Either::Left(once(row)); + } + + // form new rows with each additional line + let new_lines: Vec<_> = lines + .into_iter() + .enumerate() + .map(|(idx, l)| { + // replace any leading spaces with '-' as + // `sqllogictests` ignores whitespace differences + // + // See https://github.com/apache/datafusion/issues/6328 + let content = l.trim_start(); + let new_prefix = "-".repeat(l.len() - content.len()); + // maintain for each line a number, so + // reviewing explain result changes is easier + let line_num = idx + 1; + vec![format!("{line_num:02}){new_prefix}{content}")] + }) + .collect(); + + Either::Right(once(row).chain(new_lines)) + } else { + Either::Left(once(row)) + } +} + +/// Convert a single batch to a `Vec>` for comparison +fn convert_batch(batch: RecordBatch) -> anyhow::Result>> { + (0..batch.num_rows()) + .map(|row| { + batch + .columns() + .iter() + .map(|col| cell_to_string(col, row)) + .collect::>>() + }) + .collect() +} + +macro_rules! get_row_value { + ($array_type:ty, $column: ident, $row: ident) => {{ + let array = $column.as_any().downcast_ref::<$array_type>().unwrap(); + + array.value($row) + }}; +} + +/// Normalizes the content of a single cell in RecordBatch prior to printing. +/// +/// This is to make the output comparable to the semi-standard .slt format +/// +/// Normalizations applied to [NULL Values and empty strings] +/// +/// [NULL Values and empty strings]: https://duckdb.org/dev/sqllogictest/result_verification#null-values-and-empty-strings +/// +/// Floating numbers are rounded to have a consistent representation with the Postgres runner. 
+pub fn cell_to_string(col: &ArrayRef, row: usize) -> anyhow::Result { + if !col.is_valid(row) { + // represent any null value with the string "NULL" + Ok(NULL_STR.to_string()) + } else { + match col.data_type() { + DataType::Null => Ok(NULL_STR.to_string()), + DataType::Boolean => Ok(bool_to_str(get_row_value!(BooleanArray, col, row))), + DataType::Float16 => Ok(f16_to_str(get_row_value!(Float16Array, col, row))), + DataType::Float32 => Ok(f32_to_str(get_row_value!(Float32Array, col, row))), + DataType::Float64 => Ok(f64_to_str(get_row_value!(Float64Array, col, row))), + DataType::Decimal128(precision, scale) => { + let value = get_row_value!(Decimal128Array, col, row); + Ok(i128_to_str(value, precision, scale)) + } + DataType::Decimal256(precision, scale) => { + let value = get_row_value!(Decimal256Array, col, row); + Ok(i256_to_str(value, precision, scale)) + } + DataType::LargeUtf8 => Ok(varchar_to_str(get_row_value!(LargeStringArray, col, row))), + DataType::Utf8 => Ok(varchar_to_str(get_row_value!(StringArray, col, row))), + DataType::Utf8View => Ok(varchar_to_str(get_row_value!(StringViewArray, col, row))), + _ => { + let f = ArrayFormatter::try_new(col.as_ref(), &DEFAULT_FORMAT_OPTIONS); + Ok(f.unwrap().value(row).to_string()) + } + } + } +} + +/// Converts columns to a result as expected by sqllogicteset. +pub(crate) fn convert_schema_to_types(columns: &Fields) -> Vec { + columns + .iter() + .map(|f| f.data_type()) + .map(|data_type| match data_type { + DataType::Boolean => DFColumnType::Boolean, + DataType::Int8 + | DataType::Int16 + | DataType::Int32 + | DataType::Int64 + | DataType::UInt8 + | DataType::UInt16 + | DataType::UInt32 + | DataType::UInt64 => DFColumnType::Integer, + DataType::Float16 + | DataType::Float32 + | DataType::Float64 + | DataType::Decimal128(_, _) + | DataType::Decimal256(_, _) => DFColumnType::Float, + DataType::Utf8 | DataType::LargeUtf8 | DataType::Utf8View => DFColumnType::Text, + DataType::Date32 | DataType::Date64 | DataType::Time32(_) | DataType::Time64(_) => { + DFColumnType::DateTime + } + DataType::Timestamp(_, _) => DFColumnType::Timestamp, + _ => DFColumnType::Another, + }) + .collect() +} diff --git a/crates/sqllogictests/src/engine/datafusion.rs b/crates/sqllogictests/src/engine/datafusion.rs new file mode 100644 index 0000000000..19830fdc3b --- /dev/null +++ b/crates/sqllogictests/src/engine/datafusion.rs @@ -0,0 +1,126 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use std::collections::HashMap; +use std::sync::Arc; +use std::time::Duration; + +use anyhow::anyhow; +use arrow::array::RecordBatch; +use async_trait::async_trait; +use datafusion::catalog::CatalogProvider; +use datafusion::physical_plan::common::collect; +use datafusion::physical_plan::execute_stream; +use datafusion::prelude::{SessionConfig, SessionContext}; +use iceberg_catalog_rest::{RestCatalog, RestCatalogConfig}; +use iceberg_datafusion::IcebergCatalogProvider; +use sqllogictest::{AsyncDB, DBOutput}; +use toml::Table; + +use crate::display::normalize; +use crate::engine::output::{DFColumnType, DFOutput}; +use crate::error::{Error, Result}; + +pub struct DataFusionEngine { + ctx: SessionContext, +} + +impl Default for DataFusionEngine { + fn default() -> Self { + let config = SessionConfig::new().with_target_partitions(4); + + let ctx = SessionContext::new_with_config(config); + + Self { ctx } + } +} + +#[async_trait] +impl AsyncDB for DataFusionEngine { + type Error = Error; + type ColumnType = DFColumnType; + + async fn run(&mut self, sql: &str) -> Result<DFOutput> { + Ok(run_query(&self.ctx, sql).await?) + } + + /// Engine name of current database. + fn engine_name(&self) -> &str { + "DataFusion" + } + + /// [`DataFusionEngine`] calls this function to perform sleep. + /// + /// The default implementation is `std::thread::sleep`, which is universal to any async runtime + /// but would block the current thread. If you are running in tokio runtime, you should override + /// this by `tokio::time::sleep`. + async fn sleep(dur: Duration) { + tokio::time::sleep(dur).await; + } +} + +async fn run_query(ctx: &SessionContext, sql: impl Into<String>) -> anyhow::Result<DFOutput> { + let df = ctx.sql(sql.into().as_str()).await?; + let task_ctx = Arc::new(df.task_ctx()); + let plan = df.create_physical_plan().await?; + + let stream = execute_stream(plan, task_ctx)?; + let types = normalize::convert_schema_to_types(stream.schema().fields()); + let results: Vec<RecordBatch> = collect(stream).await?; + let rows = normalize::convert_batches(results)?; + + if rows.is_empty() && types.is_empty() { + Ok(DBOutput::StatementComplete(0)) + } else { + Ok(DBOutput::Rows { types, rows }) + } +} + +impl DataFusionEngine { + pub async fn new(configs: &Table) -> Result<Self> { + let config = SessionConfig::new().with_target_partitions(4); + + let ctx = SessionContext::new_with_config(config); + ctx.register_catalog("demo", Self::create_catalog(configs).await?); + + Ok(Self { ctx }) + } + + async fn create_catalog(configs: &Table) -> anyhow::Result<Arc<dyn CatalogProvider>> { + let rest_catalog_url = configs + .get("url") + .ok_or_else(|| anyhow!("url not found for datafusion engine!"))?
+ .as_str() + .ok_or_else(|| anyhow!("url is not str"))?; + + let rest_catalog_config = RestCatalogConfig::builder() + .uri(rest_catalog_url.to_string()) + .props(HashMap::from([ + ("s3.endpoint".to_string(), "http://localhost:9000".to_string()), + ("s3.access-key-id".to_string(), "admin".to_string()), + ("s3.secret-access-key".to_string(), "password".to_string()), + ("s3.region".to_string(), "us-east-1".to_string()), + ])) + .build(); + + let rest_catalog = RestCatalog::new(rest_catalog_config); + + Ok(Arc::new( + IcebergCatalogProvider::try_new(Arc::new(rest_catalog)).await?, + )) + } +} diff --git a/crates/sqllogictests/src/engine/mod.rs b/crates/sqllogictests/src/engine/mod.rs new file mode 100644 index 0000000000..7e31559243 --- /dev/null +++ b/crates/sqllogictests/src/engine/mod.rs @@ -0,0 +1,78 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::sync::Arc; + +use anyhow::anyhow; +pub use datafusion::*; +use sqllogictest::{strict_column_validator, AsyncDB, MakeConnection, Runner}; +use toml::Table; + +pub mod output; + +mod spark; +pub use spark::*; + +mod datafusion; + +use crate::error::Result; + +#[derive(Clone)] +pub enum Engine { + DataFusion(Arc<Table>), + SparkSQL(Arc<Table>), +} + +impl Engine { + pub async fn new(typ: &str, configs: &Table) -> Result<Self> { + let configs = Arc::new(configs.clone()); + match typ { + "spark" => Ok(Engine::SparkSQL(configs)), + "datafusion" => Ok(Engine::DataFusion(configs)), + other => Err(anyhow!("Unknown engine type: {other}").into()), + } + } + + pub async fn run_slt_file(self, slt_file: impl Into<String>) -> anyhow::Result<()> { + let absolute_file = format!( + "{}/testdata/slts/{}", + env!("CARGO_MANIFEST_DIR"), + slt_file.into() + ); + + match self { + Engine::DataFusion(configs) => { + let configs = configs.clone(); + let runner = Runner::new(|| async { DataFusionEngine::new(&configs).await }); + Self::run_with_runner(runner, absolute_file).await + } + Engine::SparkSQL(configs) => { + let configs = configs.clone(); + let runner = Runner::new(|| async { SparkSqlEngine::new(&configs).await }); + Self::run_with_runner(runner, absolute_file).await + } + } + } + + async fn run_with_runner<D: AsyncDB, M: MakeConnection<Conn = D>>( + mut runner: Runner<D, M>, + slt_file: String, + ) -> anyhow::Result<()> { + runner.with_column_validator(strict_column_validator); + Ok(runner.run_file_async(slt_file).await?) + } +} diff --git a/crates/sqllogictests/src/engine/output.rs b/crates/sqllogictests/src/engine/output.rs new file mode 100644 index 0000000000..24299856e0 --- /dev/null +++ b/crates/sqllogictests/src/engine/output.rs @@ -0,0 +1,57 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use sqllogictest::{ColumnType, DBOutput}; + +#[derive(Debug, PartialEq, Eq, Clone)] +pub enum DFColumnType { + Boolean, + DateTime, + Integer, + Float, + Text, + Timestamp, + Another, +} + +impl ColumnType for DFColumnType { + fn from_char(value: char) -> Option<Self> { + match value { + 'B' => Some(Self::Boolean), + 'D' => Some(Self::DateTime), + 'I' => Some(Self::Integer), + 'P' => Some(Self::Timestamp), + 'R' => Some(Self::Float), + 'T' => Some(Self::Text), + _ => Some(Self::Another), + } + } + + fn to_char(&self) -> char { + match self { + Self::Boolean => 'B', + Self::DateTime => 'D', + Self::Integer => 'I', + Self::Timestamp => 'P', + Self::Float => 'R', + Self::Text => 'T', + Self::Another => '?', + } + } +} + +pub(crate) type DFOutput = DBOutput<DFColumnType>; diff --git a/crates/sqllogictests/src/engine/spark.rs b/crates/sqllogictests/src/engine/spark.rs new file mode 100644 index 0000000000..f180219ce1 --- /dev/null +++ b/crates/sqllogictests/src/engine/spark.rs @@ -0,0 +1,90 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License.
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::time::Duration; + +use anyhow::anyhow; +use async_trait::async_trait; +use spark_connect_rs::{SparkSession, SparkSessionBuilder}; +use sqllogictest::{AsyncDB, DBOutput}; +use toml::Table; + +use crate::display::normalize_51; +use crate::engine::output::DFColumnType; +use crate::error::{Error, Result}; + +/// SparkSql engine implementation for sqllogictests. +pub struct SparkSqlEngine { + session: SparkSession, +} + +#[async_trait] +impl AsyncDB for SparkSqlEngine { + type Error = Error; + type ColumnType = DFColumnType; + + async fn run(&mut self, sql: &str) -> Result<DBOutput<Self::ColumnType>> { + let results = self + .session + .sql(sql) + .await + .map_err(|e| anyhow!(e))? + .collect() + .await + .map_err(|e| anyhow!(e))?; + let types = normalize_51::convert_schema_to_types(results.schema().fields()); + let rows = normalize_51::convert_batches(vec![results])?; + + if rows.is_empty() && types.is_empty() { + Ok(DBOutput::StatementComplete(0)) + } else { + Ok(DBOutput::Rows { types, rows }) + } + } + + /// Engine name of current database. + fn engine_name(&self) -> &str { + "SparkConnect" + } + + /// [`SparkSqlEngine`] calls this function to perform sleep. + /// + /// The default implementation is `std::thread::sleep`, which is universal to any async runtime + /// but would block the current thread. If you are running in tokio runtime, you should override + /// this by `tokio::time::sleep`. + async fn sleep(dur: Duration) { + tokio::time::sleep(dur).await; + } +} + +impl SparkSqlEngine { + pub async fn new(configs: &Table) -> Result<Self> { + let url = configs + .get("url") + .ok_or_else(|| anyhow!("url property doesn't exist for spark engine"))? + .as_str() + .ok_or_else(|| anyhow!("url property is not a string for spark engine"))?; + + let session = SparkSessionBuilder::remote(url) + .app_name("SparkConnect") + .build() + .await + .map_err(|e| anyhow!(e))?; + + Ok(Self { session }) + } +} diff --git a/crates/sqllogictests/src/error.rs b/crates/sqllogictests/src/error.rs new file mode 100644 index 0000000000..a08f758b13 --- /dev/null +++ b/crates/sqllogictests/src/error.rs @@ -0,0 +1,28 @@ +use std::fmt::{Debug, Display, Formatter}; + +pub struct Error(pub anyhow::Error); +pub type Result<T> = std::result::Result<T, Error>; + +impl Debug for Error { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!(f, "{:?}", self.0) + } +} + +impl Display for Error { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.0) + } +} + +impl std::error::Error for Error { + fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { + self.0.source() + } +} + +impl From<anyhow::Error> for Error { + fn from(value: anyhow::Error) -> Self { + Self(value) + } +} diff --git a/crates/sqllogictests/src/lib.rs b/crates/sqllogictests/src/lib.rs new file mode 100644 index 0000000000..d907257ca9 --- /dev/null +++ b/crates/sqllogictests/src/lib.rs @@ -0,0 +1,25 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership.
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +// This lib contains code copied from +// [Apache DataFusion](https://github.com/apache/datafusion/tree/main/datafusion/sqllogictest) +mod display; +mod engine; +mod error; +pub mod schedule; + +pub use error::*; diff --git a/crates/sqllogictests/src/schedule.rs b/crates/sqllogictests/src/schedule.rs new file mode 100644 index 0000000000..eceb5d9b77 --- /dev/null +++ b/crates/sqllogictests/src/schedule.rs @@ -0,0 +1,135 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::collections::HashMap; +use std::fs::read_to_string; +use std::path::Path; + +use anyhow::anyhow; +use itertools::Itertools; +use toml::{Table, Value}; + +use crate::engine::Engine; + +/// Schedule of engines to run tests. +pub struct Schedule { + /// Map of engine names to engine instances. + engines: HashMap<String, Engine>, + /// List of steps to run; each step is a SQL file. + steps: Vec<Step>, +} + +pub struct Step { + /// Name of engine to execute. + engine_name: String, + /// Name of the SQL file. + sql: String, +} + +impl Schedule { + pub async fn parse<P: AsRef<Path>>(schedule_def_file: P) -> anyhow::Result<Self> { + let content = read_to_string(schedule_def_file)?; + let toml_value = content.parse::<Value>()?; + let toml_table = toml_value + .as_table() + .ok_or_else(|| anyhow::anyhow!("Schedule file must be a TOML table"))?; + + let engines = Schedule::parse_engines(toml_table).await?; + let steps = Schedule::parse_steps(toml_table).await?; + + Ok(Self { engines, steps }) + } + + async fn parse_engines(table: &Table) -> anyhow::Result<HashMap<String, Engine>> { + let engines = table + .get("engines") + .ok_or_else(|| anyhow::anyhow!("Schedule file must have an 'engines' table"))? + .as_table() + .ok_or_else(|| anyhow::anyhow!("'engines' must be a table"))?; + + let mut result = HashMap::new(); + for (name, engine_config) in engines { + let engine_configs = engine_config + .as_table() + .ok_or_else(|| anyhow::anyhow!("Config of engine {name} is not a table"))?; + + let typ = engine_configs + .get("type") + .ok_or_else(|| anyhow::anyhow!("Engine {name} doesn't have a 'type' field"))?
+ .as_str() + .ok_or_else(|| anyhow::anyhow!("Engine {name} type must be a string"))?; + + let engine = Engine::new(typ, engine_configs).await?; + + result.insert(name.clone(), engine); + } + + Ok(result) + } + + async fn parse_steps(table: &Table) -> anyhow::Result<Vec<Step>> { + let steps = table + .get("steps") + .ok_or_else(|| anyhow!("steps not found"))? + .as_array() + .ok_or_else(|| anyhow!("steps is not an array"))?; + + steps.iter().map(Schedule::parse_step).try_collect() + } + + fn parse_step(value: &Value) -> anyhow::Result<Step> { + let t = value + .as_table() + .ok_or_else(|| anyhow!("Step must be a table!"))?; + + let engine_name = t + .get("engine") + .ok_or_else(|| anyhow!("Property engine is missing in step"))? + .as_str() + .ok_or_else(|| anyhow!("Property engine is not a string in step"))? + .to_string(); + + let sql = t + .get("sql") + .ok_or_else(|| anyhow!("Property sql is missing in step"))? + .as_str() + .ok_or_else(|| anyhow!("Property sql is not a string in step"))? + .to_string(); + + Ok(Step { engine_name, sql }) + } + + pub async fn run(self) -> anyhow::Result<()> { + for step_idx in 0..self.steps.len() { + self.run_step(step_idx).await?; + } + + Ok(()) + } + + async fn run_step(&self, step_index: usize) -> anyhow::Result<()> { + let step = &self.steps[step_index]; + + let engine = self + .engines + .get(&step.engine_name) + .ok_or_else(|| anyhow!("Engine {} not found!", step.engine_name))? + .clone(); + + engine.run_slt_file(&step.sql).await + } +} diff --git a/crates/sqllogictests/testdata/docker/docker-compose.yaml b/crates/sqllogictests/testdata/docker/docker-compose.yaml new file mode 100644 index 0000000000..9d24069b4c --- /dev/null +++ b/crates/sqllogictests/testdata/docker/docker-compose.yaml @@ -0,0 +1,88 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License.
+ + +services: + rest: + image: tabulario/iceberg-rest:0.10.0 + environment: + - AWS_ACCESS_KEY_ID=admin + - AWS_SECRET_ACCESS_KEY=password + - AWS_REGION=us-east-1 + - CATALOG_CATOLOG__IMPL=org.apache.iceberg.jdbc.JdbcCatalog + - CATALOG_URI=jdbc:sqlite:file:/tmp/iceberg_rest_mode=memory + - CATALOG_WAREHOUSE=s3://icebergdata/demo + - CATALOG_IO__IMPL=org.apache.iceberg.aws.s3.S3FileIO + - CATALOG_S3_ENDPOINT=http://minio:9000 + depends_on: + - minio +# networks: +# host: +# aliases: +# - icebergdata.minio + ports: + - "8181:8181" + + minio: + image: minio/minio:RELEASE.2024-03-07T00-43-48Z + environment: + - MINIO_ROOT_USER=admin + - MINIO_ROOT_PASSWORD=password + - MINIO_DOMAIN=minio + hostname: icebergdata.minio + ports: + - "9001:9001" + - "9000:9000" + command: ["server", "/data", "--console-address", ":9001"] + + mc: + depends_on: + - minio + image: minio/mc:RELEASE.2024-03-07T00-31-49Z + environment: + - AWS_ACCESS_KEY_ID=admin + - AWS_SECRET_ACCESS_KEY=password + - AWS_REGION=us-east-1 + entrypoint: > + /bin/sh -c " until (/usr/bin/mc config host add minio http://minio:9000 admin password) do echo '...waiting...' && sleep 1; done; /usr/bin/mc rm -r --force minio/icebergdata; /usr/bin/mc mb minio/icebergdata; /usr/bin/mc policy set public minio/icebergdata; tail -f /dev/null " + + spark: + depends_on: + - rest + - minio + image: apache/spark:3.5.2-java17 + environment: + - AWS_ACCESS_KEY_ID=admin + - AWS_SECRET_ACCESS_KEY=password + - AWS_REGION=us-east-1 + - SPARK_HOME=/opt/spark + - PATH=/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin:/opt/spark/bin:/opt/spark/sbin +# - HTTP_PROXY=http://host.docker.internal:7890 +# - HTTPS_PROXY=http://host.docker.internal:7890 + user: root + links: + - minio:icebergdata.minio + ports: + - "15002:15002" + healthcheck: + test: netstat -ltn | grep -c 15002 + interval: 1s + retries: 1200 + volumes: + - ./spark:/spark-script + entrypoint: [ "/spark-script/spark-connect-server.sh" ] + diff --git a/crates/sqllogictests/testdata/docker/spark/spark-connect-server.sh b/crates/sqllogictests/testdata/docker/spark/spark-connect-server.sh new file mode 100755 index 0000000000..31c064c4e5 --- /dev/null +++ b/crates/sqllogictests/testdata/docker/spark/spark-connect-server.sh @@ -0,0 +1,27 @@ +#!/bin/bash + +set -ex + +SPARK_VERSION="3.5.2" +ICEBERG_VERSION="1.6.0" + +PACKAGES="org.apache.iceberg:iceberg-spark-runtime-3.5_2.12:$ICEBERG_VERSION" +PACKAGES="$PACKAGES,org.apache.iceberg:iceberg-aws-bundle:$ICEBERG_VERSION" +PACKAGES="$PACKAGES,org.apache.spark:spark-connect_2.12:$SPARK_VERSION" + +/opt/spark/sbin/start-connect-server.sh \ + --packages $PACKAGES \ + --master local[3] \ + --conf spark.driver.extraJavaOptions="-Dlog4j.configuration=file:///spark-script/log4j2.properties" \ + --conf spark.driver.bindAddress=0.0.0.0 \ + --conf spark.sql.catalog.demo=org.apache.iceberg.spark.SparkCatalog \ + --conf spark.sql.extensions=org.apache.iceberg.spark.extensions.IcebergSparkSessionExtensions \ + --conf spark.sql.catalog.demo.catalog-impl=org.apache.iceberg.rest.RESTCatalog \ + --conf spark.sql.catalog.demo.uri=http://rest:8181 \ + --conf spark.sql.catalog.demo.s3.endpoint=http://minio:9000 \ + --conf spark.sql.catalog.demo.s3.path.style.access=true \ + --conf spark.sql.catalog.demo.s3.access.key=admin \ + --conf spark.sql.catalog.demo.s3.secret.key=password \ + --conf spark.sql.defaultCatalog=demo + +tail -f /opt/spark/logs/spark*.out \ No newline at end of file diff --git a/crates/sqllogictests/testdata/schedules/demo.toml 
b/crates/sqllogictests/testdata/schedules/demo.toml new file mode 100644 index 0000000000..ad4773d248 --- /dev/null +++ b/crates/sqllogictests/testdata/schedules/demo.toml @@ -0,0 +1,12 @@ +[engines] +spark = { type = "spark", url = "sc://localhost:15002" } +df = { type = "datafusion", url = "http://localhost:8181" } + +[[steps]] +engine = "spark" +sql = "demo/prepare.slt" + +[[steps]] +engine = "df" +sql = "demo/verify.slt" + diff --git a/crates/sqllogictests/testdata/slts/demo/prepare.slt b/crates/sqllogictests/testdata/slts/demo/prepare.slt new file mode 100644 index 0000000000..c7fb278d56 --- /dev/null +++ b/crates/sqllogictests/testdata/slts/demo/prepare.slt @@ -0,0 +1,11 @@ +statement ok +CREATE SCHEMA IF NOT EXISTS s1 ; + +statement ok +USE SCHEMA s1; + +statement ok +CREATE TABLE t1 (id INTEGER); + +statement ok +INSERT INTO t1 VALUES (1), (2), (3); diff --git a/crates/sqllogictests/testdata/slts/demo/verify.slt b/crates/sqllogictests/testdata/slts/demo/verify.slt new file mode 100644 index 0000000000..db0b1420b1 --- /dev/null +++ b/crates/sqllogictests/testdata/slts/demo/verify.slt @@ -0,0 +1,6 @@ +query I rowsort +SELECT * FROM demo.s1.t1; +---- +1 +2 +3 \ No newline at end of file diff --git a/crates/sqllogictests/tests/sqllogictests.rs b/crates/sqllogictests/tests/sqllogictests.rs new file mode 100644 index 0000000000..07fb190714 --- /dev/null +++ b/crates/sqllogictests/tests/sqllogictests.rs @@ -0,0 +1,90 @@ +use std::fs; +use std::path::PathBuf; + +use iceberg_test_utils::docker::DockerCompose; +use libtest_mimic::{Arguments, Trial}; +use sqllogictests::schedule::Schedule; +use tokio::runtime::Handle; + +fn main() { + env_logger::init(); + + log::info!("Starting docker compose..."); + let docker = start_docker().unwrap(); + + log::info!("Starting tokio runtime..."); + let rt = tokio::runtime::Builder::new_multi_thread() + .enable_all() + .build() + .unwrap(); + + // Parse command line arguments + let args = Arguments::from_args(); + + log::info!("Creating tests..."); + let tests = collect_trials(rt.handle().clone()).unwrap(); + + log::info!("Starting tests..."); + // Run all tests and exit the application appropriately. + let result = libtest_mimic::run(&args, tests); + + log::info!("Shutting down tokio runtime..."); + drop(rt); + log::info!("Shutting down docker..."); + drop(docker); + + result.exit(); +} + +fn start_docker() -> anyhow::Result<DockerCompose> { + let docker = DockerCompose::new( + "sqllogictests", + format!("{}/testdata/docker", env!("CARGO_MANIFEST_DIR")), + ); + docker.run(); + Ok(docker) +} + +fn collect_trials(handle: Handle) -> anyhow::Result<Vec<Trial>> { + let schedule_files = collect_schedule_files()?; + log::debug!( + "Found {} schedule files: {:?}", + schedule_files.len(), + &schedule_files + ); + let mut trials = Vec::with_capacity(schedule_files.len()); + for schedule_file in schedule_files { + let h = handle.clone(); + let trial_name = format!( + "Test schedule {}", + schedule_file + .file_name() + .expect("Schedule file should have a name") + .to_string_lossy() + ); + let trial = Trial::test(trial_name, move || { + Ok(h.block_on(run_schedule(schedule_file.clone()))?) + }); + trials.push(trial); + } + Ok(trials) +} + +fn collect_schedule_files() -> anyhow::Result<Vec<PathBuf>> { + let dir = PathBuf::from(format!("{}/testdata/schedules", env!("CARGO_MANIFEST_DIR"))); + let mut schedule_files = Vec::with_capacity(32); + for entry in fs::read_dir(&dir)?
{ + let entry = entry?; + let path = entry.path(); + if path.is_file() { + schedule_files.push(fs::canonicalize(dir.join(path))?); + } + } + Ok(schedule_files) +} + +async fn run_schedule(schedule_file: PathBuf) -> anyhow::Result<()> { + let schedule = Schedule::parse(schedule_file).await?; + schedule.run().await?; + Ok(()) +} diff --git a/crates/test_utils/src/docker.rs b/crates/test_utils/src/docker.rs index bde9737b17..6d2de35777 100644 --- a/crates/test_utils/src/docker.rs +++ b/crates/test_utils/src/docker.rs @@ -102,13 +102,11 @@ impl DockerCompose { let ip_result = get_cmd_output(cmd, format!("Get container ip of {container_name}")) .trim() .parse::(); - match ip_result { - Ok(ip) => ip, - Err(e) => { - log::error!("Invalid IP, {e}"); - panic!("Failed to parse IP for {container_name}") - } - } + + ip_result.unwrap_or_else(|e| { + log::error!("Invalid IP, {e}"); + panic!("Failed to parse IP for {container_name}") + }) } } @@ -126,12 +124,12 @@ impl Drop for DockerCompose { "--remove-orphans", ]); - run_command( - cmd, - format!( - "Stopping docker compose in {}, project name: {}", - self.docker_compose_dir, self.project_name - ), - ) + // run_command( + // cmd, + // format!( + // "Stopping docker compose in {}, project name: {}", + // self.docker_compose_dir, self.project_name + // ), + // ) } }