9 changes: 9 additions & 0 deletions bin/bucketd/src/cli.rs
@@ -125,6 +125,15 @@ pub struct IceBucketOpts {
help = "CORS Allow Origin"
)]
pub cors_allow_origin: Option<String>,

#[arg(
short,
long,
default_value = "json",
env = "DATA_FORMAT",
help = "Data serialization format in Snowflake v1 API"
)]
pub data_format: Option<String>,
}

#[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord, ValueEnum)]
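A quick sketch of how the new flag is exercised end to end, assuming clap 4's derive API (which the `#[arg(...)]` attribute above implies) and clap's `env` feature; the struct below is a stripped-down stand-in for `IceBucketOpts`, not the real one. Worth noting for review: with `default_value = "json"` clap always yields `Some`, so the `unwrap_or_else` fallback added in main.rs below never actually fires.

```rust
use clap::Parser;

// Stand-in for IceBucketOpts; only the new flag is reproduced here.
#[derive(Parser)]
struct Opts {
    /// Data serialization format in the Snowflake v1 API ("json" or "arrow").
    #[arg(short, long, default_value = "json", env = "DATA_FORMAT")]
    data_format: Option<String>,
}

fn main() {
    // Precedence: --data-format flag, then the DATA_FORMAT env var,
    // then the "json" default baked into the arg definition.
    let opts = Opts::parse();
    let format = opts.data_format.unwrap_or_else(|| "json".to_string());
    println!("serving query results as {format}");
}
```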
15 changes: 13 additions & 2 deletions bin/bucketd/src/main.rs
@@ -43,6 +43,10 @@ async fn main() {
} else {
None
};
let dbt_serialization_format = opts
.data_format
.clone()
.unwrap_or_else(|| "json".to_string());
let object_store = opts.object_store_backend();

match object_store {
@@ -53,8 +57,15 @@
Ok(object_store) => {
tracing::info!("Starting 🧊🪣 IceBucket...");

if let Err(e) =
nexus::run_icebucket(object_store, slatedb_prefix, host, port, allow_origin).await
if let Err(e) = nexus::run_icebucket(
object_store,
slatedb_prefix,
host,
port,
allow_origin,
&dbt_serialization_format,
)
.await
{
tracing::error!("Failed to start IceBucket: {:?}", e);
}
12 changes: 6 additions & 6 deletions crates/control_plane/src/service.rs
@@ -114,13 +114,14 @@ impl ControlServiceImpl {
pub fn new(
storage_profile_repo: Arc<dyn StorageProfileRepository>,
warehouse_repo: Arc<dyn WarehouseRepository>,
config: Config,
) -> Self {
let df_sessions = Arc::new(RwLock::new(HashMap::new()));
Self {
storage_profile_repo,
warehouse_repo,
df_sessions,
config: Config::default(),
config,
}
}
}
@@ -341,10 +342,9 @@ impl ControlService for ControlServiceImpl {
.into_iter()
.collect::<Vec<_>>();

let serialization_format = self.config().dbt_serialization_format;
let data_format = self.config().data_format;
// Add columns dbt metadata to each field
convert_record_batches(records, serialization_format)
.context(error::DataFusionQuerySnafu { query })
convert_record_batches(records, data_format).context(error::DataFusionQuerySnafu { query })
}

#[tracing::instrument(level = "debug", skip(self))]
@@ -587,7 +587,7 @@ mod tests {
fn service() -> ControlServiceImpl {
let storage_repo = Arc::new(InMemoryStorageProfileRepository::default());
let warehouse_repo = Arc::new(InMemoryWarehouseRepository::default());
ControlServiceImpl::new(storage_repo, warehouse_repo)
ControlServiceImpl::new(storage_repo, warehouse_repo, Config::new("json"))
}

fn storage_profile_req() -> StorageProfileCreateRequest {
@@ -811,7 +811,7 @@
storage_repo: Arc<dyn StorageProfileRepository>,
warehouse_repo: Arc<dyn WarehouseRepository>,
) {
let service = ControlServiceImpl::new(storage_repo, warehouse_repo);
let service = ControlServiceImpl::new(storage_repo, warehouse_repo, Config::new("json"));
service
.create_session("TEST_SESSION".to_string())
.await
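For callers, the practical effect of this change: `Config` is now injected rather than silently defaulting inside `new()`. A minimal wiring sketch under assumed module paths (the in-memory repositories appear in this file's tests; their exact import paths are a guess):

```rust
use std::sync::Arc;

use control_plane::service::ControlServiceImpl;
use control_plane::utils::Config;
// Import paths assumed; these repositories are used in service.rs's tests.
use control_plane::repository::{InMemoryStorageProfileRepository, InMemoryWarehouseRepository};

fn build_service() -> ControlServiceImpl {
    let storage_repo = Arc::new(InMemoryStorageProfileRepository::default());
    let warehouse_repo = Arc::new(InMemoryWarehouseRepository::default());
    // Anything other than the exact string "arrow" resolves to DataFormat::Json.
    ControlServiceImpl::new(storage_repo, warehouse_repo, Config::new("arrow"))
}
```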
56 changes: 26 additions & 30 deletions crates/control_plane/src/utils.rs
@@ -31,42 +31,41 @@ use rusoto_core::{HttpClient, Region};
use rusoto_credential::StaticProvider;
use rusoto_s3::{GetBucketAclOutput, GetBucketAclRequest, S3Client as ExternalS3Client, S3};
use snafu::ResultExt;
use std::fmt::Display;
use std::fmt;
use std::sync::Arc;
use std::{env, fmt};

pub struct Config {
pub dbt_serialization_format: SerializationFormat,
pub data_format: DataFormat,
}

impl Default for Config {
fn default() -> Self {
impl Config {
#[must_use]
pub fn new(data_format: &str) -> Self {
Self {
dbt_serialization_format: SerializationFormat::new(),
data_format: DataFormat::from_str(data_format),
}
}
}
#[derive(Copy, Clone, PartialEq, Eq)]
pub enum SerializationFormat {
pub enum DataFormat {
Arrow,
Json,
}

impl Display for SerializationFormat {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Self::Arrow => write!(f, "arrow"),
Self::Json => write!(f, "json"),
impl DataFormat {
fn from_str(value: &str) -> Self {
match value {
"arrow" => Self::Arrow,
_ => Self::Json,
}
}
}

impl SerializationFormat {
fn new() -> Self {
let var = env::var("DBT_SERIALIZATION_FORMAT").unwrap_or_else(|_| "json".to_string());
match var.to_lowercase().as_str() {
"arrow" => Self::Arrow,
_ => Self::Json,
impl fmt::Display for DataFormat {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Self::Arrow => write!(f, "arrow"),
Self::Json => write!(f, "json"),
}
}
}
@@ -143,7 +142,7 @@ pub fn first_non_empty_type(union_array: &UnionArray) -> Option<(DataType, Array

pub fn convert_record_batches(
records: Vec<RecordBatch>,
serialization_format: SerializationFormat,
data_format: DataFormat,
) -> DataFusionResult<(Vec<RecordBatch>, Vec<ColumnInfo>)> {
let mut converted_batches = Vec::new();
let column_infos = ColumnInfo::from_batch(&records);
@@ -173,8 +172,7 @@ pub fn convert_record_batches(
}
}
DataType::Timestamp(unit, _) => {
let converted_column =
convert_timestamp_to_struct(column, *unit, serialization_format);
let converted_column = convert_timestamp_to_struct(column, *unit, data_format);
fields.push(
Field::new(
field.name(),
@@ -218,10 +216,10 @@ macro_rules! downcast_and_iter {
fn convert_timestamp_to_struct(
column: &ArrayRef,
unit: TimeUnit,
ser: SerializationFormat,
data_format: DataFormat,
) -> ArrayRef {
match ser {
SerializationFormat::Arrow => {
match data_format {
DataFormat::Arrow => {
let timestamps: Vec<_> = match unit {
TimeUnit::Second => downcast_and_iter!(column, TimestampSecondArray).collect(),
TimeUnit::Millisecond => {
@@ -236,7 +234,7 @@ fn convert_timestamp_to_struct(
};
Arc::new(Int64Array::from(timestamps)) as ArrayRef
}
SerializationFormat::Json => {
DataFormat::Json => {
let timestamps: Vec<_> = match unit {
TimeUnit::Second => downcast_and_iter!(column, TimestampSecondArray)
.map(|x| {
@@ -344,8 +342,7 @@ mod tests {
Arc::new(TimestampNanosecondArray::from(values)) as ArrayRef
}
};
let result =
convert_timestamp_to_struct(&timestamp_array, *unit, SerializationFormat::Json);
let result = convert_timestamp_to_struct(&timestamp_array, *unit, DataFormat::Json);
let string_array = result.as_any().downcast_ref::<StringArray>().unwrap();
assert_eq!(string_array.len(), 2);
assert_eq!(string_array.value(0), *expected);
@@ -372,7 +369,7 @@
let batch = RecordBatch::try_new(schema, vec![int_array, timestamp_array]).unwrap();
let records = vec![batch];
let (converted_batches, column_infos) =
convert_record_batches(records.clone(), SerializationFormat::Json).unwrap();
convert_record_batches(records.clone(), DataFormat::Json).unwrap();

let converted_batch = &converted_batches[0];
assert_eq!(converted_batches.len(), 1);
@@ -393,8 +390,7 @@
assert_eq!(column_infos[1].name, "timestamp_col");
assert_eq!(column_infos[1].r#type, "timestamp_ntz");

let (converted_batches, _) =
convert_record_batches(records, SerializationFormat::Arrow).unwrap();
let (converted_batches, _) = convert_record_batches(records, DataFormat::Arrow).unwrap();
let converted_batch = &converted_batches[0];
let converted_timestamp_array = converted_batch
.column(1)
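One behavioral wrinkle worth flagging in this rewrite: the old `SerializationFormat::new` lowercased the `DBT_SERIALIZATION_FORMAT` env var before matching, while the new `DataFormat::from_str` matches the literal string, so `"Arrow"` now silently falls back to `Json`. An illustrative test, not part of the PR (`from_str` is private, so it would have to live alongside it in utils.rs):

```rust
#[cfg(test)]
mod data_format_sketch {
    use super::DataFormat;

    #[test]
    fn from_str_is_case_sensitive_with_json_fallback() {
        // Exact match selects Arrow...
        assert!(DataFormat::from_str("arrow") == DataFormat::Arrow);
        // ...but matching is now case-sensitive (the old env-var path lowercased input)...
        assert!(DataFormat::from_str("Arrow") == DataFormat::Json);
        // ...and anything unrecognized falls back to Json.
        assert!(DataFormat::from_str("csv") == DataFormat::Json);
        // Display round-trips the canonical lowercase names.
        assert!(DataFormat::Arrow.to_string() == "arrow");
        assert!(DataFormat::Json.to_string() == "json");
    }
}
```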
10 changes: 5 additions & 5 deletions crates/nexus/src/http/dbt/handlers.rs
@@ -34,7 +34,7 @@ use axum::Json;
use base64;
use base64::engine::general_purpose::STANDARD as engine_base64;
use base64::prelude::*;
use control_plane::utils::SerializationFormat;
use control_plane::utils::DataFormat;
use flate2::read::GzDecoder;
use regex::Regex;
use snafu::ResultExt;
@@ -164,19 +164,19 @@ pub async fn query(
records_to_json_string(&records)?.as_str()
);

let serialization_format = state.control_svc.config().dbt_serialization_format;
let data_format = state.control_svc.config().data_format;
let json_resp = Json(JsonResponse {
data: Option::from(ResponseData {
row_type: columns.into_iter().map(Into::into).collect(),
query_result_format: Some(serialization_format.to_string()),
row_set: if serialization_format == SerializationFormat::Json {
query_result_format: Some(data_format.to_string()),
row_set: if data_format == DataFormat::Json {
Option::from(ResponseData::rows_to_vec(
records_to_json_string(&records)?.as_str(),
)?)
} else {
None
},
row_set_base_64: if serialization_format == SerializationFormat::Arrow {
row_set_base_64: if data_format == DataFormat::Arrow {
Option::from(records_to_arrow_string(&records)?)
} else {
None
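The invariant this hunk establishes: exactly one of `row_set` (inline JSON rows) and `row_set_base_64` (base64-encoded Arrow batches) is populated, and `query_result_format` tells the client which one to read. A self-contained sketch of that switch using stand-in types (the real `ResponseData` and encoding helpers live in this module):

```rust
#[derive(PartialEq, Clone, Copy)]
enum DataFormat { Arrow, Json }

// Stand-in for the handler's branching; payloads are pre-rendered strings here.
fn pick_payload(
    format: DataFormat,
    json_rows: String,
    arrow_b64: String,
) -> (Option<String>, Option<String>, &'static str) {
    let row_set = (format == DataFormat::Json).then_some(json_rows);
    let row_set_base_64 = (format == DataFormat::Arrow).then_some(arrow_b64);
    let query_result_format = match format {
        DataFormat::Json => "json",
        DataFormat::Arrow => "arrow",
    };
    (row_set, row_set_base_64, query_result_format)
}

fn main() {
    let (rows, b64, fmt) = pick_payload(DataFormat::Json, "[[\"1\"]]".into(), String::new());
    assert!(rows.is_some() && b64.is_none() && fmt == "json");
}
```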
8 changes: 7 additions & 1 deletion crates/nexus/src/http/router.rs
@@ -114,6 +114,7 @@ mod tests {
use catalog::service::CatalogImpl;
use control_plane::repository::{StorageProfileRepositoryDb, WarehouseRepositoryDb};
use control_plane::service::ControlServiceImpl;
use control_plane::utils::Config;
use http_body_util::BodyExt;
// for `collect`
use object_store::{memory::InMemory, path::Path, ObjectStore};
@@ -144,7 +145,12 @@
let control_svc = {
let storage_profile_repo = StorageProfileRepositoryDb::new(db.clone());
let warehouse_repo = WarehouseRepositoryDb::new(db.clone());
ControlServiceImpl::new(Arc::new(storage_profile_repo), Arc::new(warehouse_repo))
let config = Config::new("json");
ControlServiceImpl::new(
Arc::new(storage_profile_repo),
Arc::new(warehouse_repo),
config,
)
};

let catalog_svc = {
9 changes: 8 additions & 1 deletion crates/nexus/src/lib.rs
@@ -26,6 +26,7 @@ use catalog::repository::{DatabaseRepositoryDb, TableRepositoryDb};
use catalog::service::CatalogImpl;
use control_plane::repository::{StorageProfileRepositoryDb, WarehouseRepositoryDb};
use control_plane::service::ControlServiceImpl;
use control_plane::utils::Config as ControlServiceConfig;
use http_body_util::BodyExt;
use object_store::{path::Path, ObjectStore};
use slatedb::config::DbOptions;
@@ -51,6 +52,7 @@ pub async fn run_icebucket(
host: String,
port: u16,
allow_origin: Option<String>,
data_format: &str,
) -> Result<(), Box<dyn std::error::Error>> {
let db = {
let options = DbOptions::default();
@@ -64,7 +66,12 @@
let control_svc = {
let storage_profile_repo = StorageProfileRepositoryDb::new(db.clone());
let warehouse_repo = WarehouseRepositoryDb::new(db.clone());
ControlServiceImpl::new(Arc::new(storage_profile_repo), Arc::new(warehouse_repo))
let config = ControlServiceConfig::new(data_format);
ControlServiceImpl::new(
Arc::new(storage_profile_repo),
Arc::new(warehouse_repo),
config,
)
};
let control_svc = Arc::new(control_svc);

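For downstream embedders of nexus, the extra `data_format` parameter is the only breaking change in this file. A hedged call sketch; the object-store and prefix types are guessed from this crate's imports (`object_store::{path::Path, ObjectStore}`), and the host/port/prefix values are placeholders:

```rust
use std::sync::Arc;
use object_store::{memory::InMemory, path::Path, ObjectStore};

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Placeholder backend; bucketd derives this from its CLI options instead.
    let object_store: Arc<dyn ObjectStore> = Arc::new(InMemory::new());
    nexus::run_icebucket(
        object_store,
        Path::from("slatedb"),   // slatedb_prefix (placeholder)
        "127.0.0.1".to_string(), // host
        8080,                    // port (placeholder)
        None,                    // allow_origin
        "arrow",                 // data_format, parsed by DataFormat::from_str
    )
    .await
}
```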