From 83c243f828aa932217e1940c9d96f60e64aef717 Mon Sep 17 00:00:00 2001 From: osipovartem Date: Thu, 20 Feb 2025 01:36:24 +0300 Subject: [PATCH 1/2] Move dbt_ser_fmt to cli --- bin/bucketd/src/cli.rs | 9 +++++++++ bin/bucketd/src/main.rs | 15 ++++++++++++-- crates/control_plane/src/service.rs | 7 ++++--- crates/control_plane/src/utils.rs | 31 ++++++++++++++--------------- crates/nexus/src/http/router.rs | 8 +++++++- crates/nexus/src/lib.rs | 9 ++++++++- 6 files changed, 56 insertions(+), 23 deletions(-) diff --git a/bin/bucketd/src/cli.rs b/bin/bucketd/src/cli.rs index 05a128a37..f446c2a2c 100644 --- a/bin/bucketd/src/cli.rs +++ b/bin/bucketd/src/cli.rs @@ -125,6 +125,15 @@ pub struct IceBucketOpts { help = "CORS Allow Origin" )] pub cors_allow_origin: Option, + + #[arg( + short, + long, + default_value = "json", + env = "DBT_SERIALIZATION_FORMAT", + help = "Serialization format for dbt endpoints" + )] + pub dbt_serialization_format: Option, } #[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord, ValueEnum)] diff --git a/bin/bucketd/src/main.rs b/bin/bucketd/src/main.rs index c80833eb2..6c62ffc7f 100644 --- a/bin/bucketd/src/main.rs +++ b/bin/bucketd/src/main.rs @@ -43,6 +43,10 @@ async fn main() { } else { None }; + let dbt_serialization_format = opts + .dbt_serialization_format + .clone() + .unwrap_or_else(|| "json".to_string()); let object_store = opts.object_store_backend(); match object_store { @@ -53,8 +57,15 @@ async fn main() { Ok(object_store) => { tracing::info!("Starting 🧊🪣 IceBucket..."); - if let Err(e) = - nexus::run_icebucket(object_store, slatedb_prefix, host, port, allow_origin).await + if let Err(e) = nexus::run_icebucket( + object_store, + slatedb_prefix, + host, + port, + allow_origin, + &dbt_serialization_format, + ) + .await { tracing::error!("Failed to start IceBucket: {:?}", e); } diff --git a/crates/control_plane/src/service.rs b/crates/control_plane/src/service.rs index fe2dc84ea..18916e2f5 100644 --- a/crates/control_plane/src/service.rs 
+++ b/crates/control_plane/src/service.rs @@ -114,13 +114,14 @@ impl ControlServiceImpl { pub fn new( storage_profile_repo: Arc, warehouse_repo: Arc, + config: Config, ) -> Self { let df_sessions = Arc::new(RwLock::new(HashMap::new())); Self { storage_profile_repo, warehouse_repo, df_sessions, - config: Config::default(), + config, } } } @@ -587,7 +588,7 @@ mod tests { fn service() -> ControlServiceImpl { let storage_repo = Arc::new(InMemoryStorageProfileRepository::default()); let warehouse_repo = Arc::new(InMemoryWarehouseRepository::default()); - ControlServiceImpl::new(storage_repo, warehouse_repo) + ControlServiceImpl::new(storage_repo, warehouse_repo, Config::new("json")) } fn storage_profile_req() -> StorageProfileCreateRequest { @@ -811,7 +812,7 @@ mod tests { storage_repo: Arc, warehouse_repo: Arc, ) { - let service = ControlServiceImpl::new(storage_repo, warehouse_repo); + let service = ControlServiceImpl::new(storage_repo, warehouse_repo, Config::new("json")); service .create_session("TEST_SESSION".to_string()) .await diff --git a/crates/control_plane/src/utils.rs b/crates/control_plane/src/utils.rs index 657d83025..30a5bd4ff 100644 --- a/crates/control_plane/src/utils.rs +++ b/crates/control_plane/src/utils.rs @@ -31,18 +31,18 @@ use rusoto_core::{HttpClient, Region}; use rusoto_credential::StaticProvider; use rusoto_s3::{GetBucketAclOutput, GetBucketAclRequest, S3Client as ExternalS3Client, S3}; use snafu::ResultExt; -use std::fmt::Display; +use std::fmt; use std::sync::Arc; -use std::{env, fmt}; pub struct Config { pub dbt_serialization_format: SerializationFormat, } -impl Default for Config { - fn default() -> Self { +impl Config { + #[must_use] + pub fn new(serialization_format: &str) -> Self { Self { - dbt_serialization_format: SerializationFormat::new(), + dbt_serialization_format: SerializationFormat::from_str(serialization_format), } } } @@ -52,21 +52,20 @@ pub enum SerializationFormat { Json, } -impl Display for SerializationFormat { - fn 
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match self { - Self::Arrow => write!(f, "arrow"), - Self::Json => write!(f, "json"), +impl SerializationFormat { + fn from_str(value: &str) -> Self { + match value { + "arrow" => Self::Arrow, + _ => Self::Json, } } } -impl SerializationFormat { - fn new() -> Self { - let var = env::var("DBT_SERIALIZATION_FORMAT").unwrap_or_else(|_| "json".to_string()); - match var.to_lowercase().as_str() { - "arrow" => Self::Arrow, - _ => Self::Json, +impl fmt::Display for SerializationFormat { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::Arrow => write!(f, "arrow"), + Self::Json => write!(f, "json"), } } } diff --git a/crates/nexus/src/http/router.rs b/crates/nexus/src/http/router.rs index 7c13f872f..405f079c4 100644 --- a/crates/nexus/src/http/router.rs +++ b/crates/nexus/src/http/router.rs @@ -114,6 +114,7 @@ mod tests { use catalog::service::CatalogImpl; use control_plane::repository::{StorageProfileRepositoryDb, WarehouseRepositoryDb}; use control_plane::service::ControlServiceImpl; + use control_plane::utils::Config; use http_body_util::BodyExt; // for `collect` use object_store::{memory::InMemory, path::Path, ObjectStore}; @@ -144,7 +145,12 @@ mod tests { let control_svc = { let storage_profile_repo = StorageProfileRepositoryDb::new(db.clone()); let warehouse_repo = WarehouseRepositoryDb::new(db.clone()); - ControlServiceImpl::new(Arc::new(storage_profile_repo), Arc::new(warehouse_repo)) + let config = Config::new("json"); + ControlServiceImpl::new( + Arc::new(storage_profile_repo), + Arc::new(warehouse_repo), + config, + ) }; let catalog_svc = { diff --git a/crates/nexus/src/lib.rs b/crates/nexus/src/lib.rs index 17c92ee5e..426e82148 100644 --- a/crates/nexus/src/lib.rs +++ b/crates/nexus/src/lib.rs @@ -26,6 +26,7 @@ use catalog::repository::{DatabaseRepositoryDb, TableRepositoryDb}; use catalog::service::CatalogImpl; use control_plane::repository::{StorageProfileRepositoryDb, 
WarehouseRepositoryDb}; use control_plane::service::ControlServiceImpl; +use control_plane::utils::Config as ControlServiceConfig; use http_body_util::BodyExt; use object_store::{path::Path, ObjectStore}; use slatedb::config::DbOptions; @@ -51,6 +52,7 @@ pub async fn run_icebucket( host: String, port: u16, allow_origin: Option, + dbt_serialization_format: &str, ) -> Result<(), Box> { let db = { let options = DbOptions::default(); @@ -64,7 +66,12 @@ pub async fn run_icebucket( let control_svc = { let storage_profile_repo = StorageProfileRepositoryDb::new(db.clone()); let warehouse_repo = WarehouseRepositoryDb::new(db.clone()); - ControlServiceImpl::new(Arc::new(storage_profile_repo), Arc::new(warehouse_repo)) + let config = ControlServiceConfig::new(dbt_serialization_format); + ControlServiceImpl::new( + Arc::new(storage_profile_repo), + Arc::new(warehouse_repo), + config, + ) }; let control_svc = Arc::new(control_svc); From fee0ab14718af5d175a43e7a90b7117f665343c5 Mon Sep 17 00:00:00 2001 From: osipovartem Date: Thu, 20 Feb 2025 10:14:13 +0300 Subject: [PATCH 2/2] Rename the field --- bin/bucketd/src/cli.rs | 6 ++--- bin/bucketd/src/main.rs | 2 +- crates/control_plane/src/service.rs | 5 ++-- crates/control_plane/src/utils.rs | 33 ++++++++++++--------------- crates/nexus/src/http/dbt/handlers.rs | 10 ++++---- crates/nexus/src/lib.rs | 4 ++-- 6 files changed, 28 insertions(+), 32 deletions(-) diff --git a/bin/bucketd/src/cli.rs b/bin/bucketd/src/cli.rs index f446c2a2c..f65433c11 100644 --- a/bin/bucketd/src/cli.rs +++ b/bin/bucketd/src/cli.rs @@ -130,10 +130,10 @@ pub struct IceBucketOpts { short, long, default_value = "json", - env = "DBT_SERIALIZATION_FORMAT", - help = "Serialization format for dbt endpoints" + env = "DATA_FORMAT", + help = "Data serialization format in Snowflake v1 API" )] - pub dbt_serialization_format: Option, + pub data_format: Option, } #[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord, ValueEnum)] diff --git a/bin/bucketd/src/main.rs 
b/bin/bucketd/src/main.rs index 6c62ffc7f..73ff3fc5d 100644 --- a/bin/bucketd/src/main.rs +++ b/bin/bucketd/src/main.rs @@ -44,7 +44,7 @@ async fn main() { None }; let dbt_serialization_format = opts - .dbt_serialization_format + .data_format .clone() .unwrap_or_else(|| "json".to_string()); let object_store = opts.object_store_backend(); diff --git a/crates/control_plane/src/service.rs b/crates/control_plane/src/service.rs index 18916e2f5..a8fe34302 100644 --- a/crates/control_plane/src/service.rs +++ b/crates/control_plane/src/service.rs @@ -342,10 +342,9 @@ impl ControlService for ControlServiceImpl { .into_iter() .collect::>(); - let serialization_format = self.config().dbt_serialization_format; + let data_format = self.config().data_format; // Add columns dbt metadata to each field - convert_record_batches(records, serialization_format) - .context(error::DataFusionQuerySnafu { query }) + convert_record_batches(records, data_format).context(error::DataFusionQuerySnafu { query }) } #[tracing::instrument(level = "debug", skip(self))] diff --git a/crates/control_plane/src/utils.rs b/crates/control_plane/src/utils.rs index 30a5bd4ff..b8bfb6010 100644 --- a/crates/control_plane/src/utils.rs +++ b/crates/control_plane/src/utils.rs @@ -35,24 +35,24 @@ use std::fmt; use std::sync::Arc; pub struct Config { - pub dbt_serialization_format: SerializationFormat, + pub data_format: DataFormat, } impl Config { #[must_use] - pub fn new(serialization_format: &str) -> Self { + pub fn new(data_format: &str) -> Self { Self { - dbt_serialization_format: SerializationFormat::from_str(serialization_format), + data_format: DataFormat::from_str(data_format), } } } #[derive(Copy, Clone, PartialEq, Eq)] -pub enum SerializationFormat { +pub enum DataFormat { Arrow, Json, } -impl SerializationFormat { +impl DataFormat { fn from_str(value: &str) -> Self { match value { "arrow" => Self::Arrow, @@ -61,7 +61,7 @@ impl SerializationFormat { } } -impl fmt::Display for SerializationFormat { +impl 
fmt::Display for DataFormat { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { Self::Arrow => write!(f, "arrow"), @@ -142,7 +142,7 @@ pub fn first_non_empty_type(union_array: &UnionArray) -> Option<(DataType, Array pub fn convert_record_batches( records: Vec, - serialization_format: SerializationFormat, + data_format: DataFormat, ) -> DataFusionResult<(Vec, Vec)> { let mut converted_batches = Vec::new(); let column_infos = ColumnInfo::from_batch(&records); @@ -172,8 +172,7 @@ pub fn convert_record_batches( } } DataType::Timestamp(unit, _) => { - let converted_column = - convert_timestamp_to_struct(column, *unit, serialization_format); + let converted_column = convert_timestamp_to_struct(column, *unit, data_format); fields.push( Field::new( field.name(), @@ -217,10 +216,10 @@ macro_rules! downcast_and_iter { fn convert_timestamp_to_struct( column: &ArrayRef, unit: TimeUnit, - ser: SerializationFormat, + data_format: DataFormat, ) -> ArrayRef { - match ser { - SerializationFormat::Arrow => { + match data_format { + DataFormat::Arrow => { let timestamps: Vec<_> = match unit { TimeUnit::Second => downcast_and_iter!(column, TimestampSecondArray).collect(), TimeUnit::Millisecond => { @@ -235,7 +234,7 @@ fn convert_timestamp_to_struct( }; Arc::new(Int64Array::from(timestamps)) as ArrayRef } - SerializationFormat::Json => { + DataFormat::Json => { let timestamps: Vec<_> = match unit { TimeUnit::Second => downcast_and_iter!(column, TimestampSecondArray) .map(|x| { @@ -343,8 +342,7 @@ mod tests { Arc::new(TimestampNanosecondArray::from(values)) as ArrayRef } }; - let result = - convert_timestamp_to_struct(&timestamp_array, *unit, SerializationFormat::Json); + let result = convert_timestamp_to_struct(&timestamp_array, *unit, DataFormat::Json); let string_array = result.as_any().downcast_ref::().unwrap(); assert_eq!(string_array.len(), 2); assert_eq!(string_array.value(0), *expected); @@ -371,7 +369,7 @@ mod tests { let batch = RecordBatch::try_new(schema, 
vec![int_array, timestamp_array]).unwrap(); let records = vec![batch]; let (converted_batches, column_infos) = - convert_record_batches(records.clone(), SerializationFormat::Json).unwrap(); + convert_record_batches(records.clone(), DataFormat::Json).unwrap(); let converted_batch = &converted_batches[0]; assert_eq!(converted_batches.len(), 1); @@ -392,8 +390,7 @@ mod tests { assert_eq!(column_infos[1].name, "timestamp_col"); assert_eq!(column_infos[1].r#type, "timestamp_ntz"); - let (converted_batches, _) = - convert_record_batches(records, SerializationFormat::Arrow).unwrap(); + let (converted_batches, _) = convert_record_batches(records, DataFormat::Arrow).unwrap(); let converted_batch = &converted_batches[0]; let converted_timestamp_array = converted_batch .column(1) diff --git a/crates/nexus/src/http/dbt/handlers.rs b/crates/nexus/src/http/dbt/handlers.rs index 3e6c457c6..fce274be8 100644 --- a/crates/nexus/src/http/dbt/handlers.rs +++ b/crates/nexus/src/http/dbt/handlers.rs @@ -34,7 +34,7 @@ use axum::Json; use base64; use base64::engine::general_purpose::STANDARD as engine_base64; use base64::prelude::*; -use control_plane::utils::SerializationFormat; +use control_plane::utils::DataFormat; use flate2::read::GzDecoder; use regex::Regex; use snafu::ResultExt; @@ -164,19 +164,19 @@ pub async fn query( records_to_json_string(&records)?.as_str() ); - let serialization_format = state.control_svc.config().dbt_serialization_format; + let data_format = state.control_svc.config().data_format; let json_resp = Json(JsonResponse { data: Option::from(ResponseData { row_type: columns.into_iter().map(Into::into).collect(), - query_result_format: Some(serialization_format.to_string()), - row_set: if serialization_format == SerializationFormat::Json { + query_result_format: Some(data_format.to_string()), + row_set: if data_format == DataFormat::Json { Option::from(ResponseData::rows_to_vec( records_to_json_string(&records)?.as_str(), )?) 
} else { None }, - row_set_base_64: if serialization_format == SerializationFormat::Arrow { + row_set_base_64: if data_format == DataFormat::Arrow { Option::from(records_to_arrow_string(&records)?) } else { None diff --git a/crates/nexus/src/lib.rs b/crates/nexus/src/lib.rs index 426e82148..4729cf5c8 100644 --- a/crates/nexus/src/lib.rs +++ b/crates/nexus/src/lib.rs @@ -52,7 +52,7 @@ pub async fn run_icebucket( host: String, port: u16, allow_origin: Option, - dbt_serialization_format: &str, + data_format: &str, ) -> Result<(), Box> { let db = { let options = DbOptions::default(); @@ -66,7 +66,7 @@ pub async fn run_icebucket( let control_svc = { let storage_profile_repo = StorageProfileRepositoryDb::new(db.clone()); let warehouse_repo = WarehouseRepositoryDb::new(db.clone()); - let config = ControlServiceConfig::new(dbt_serialization_format); + let config = ControlServiceConfig::new(data_format); ControlServiceImpl::new( Arc::new(storage_profile_repo), Arc::new(warehouse_repo),