diff --git a/src/static_schema.rs b/src/static_schema.rs index 5b1a5cada..286ec65ad 100644 --- a/src/static_schema.rs +++ b/src/static_schema.rs @@ -18,12 +18,15 @@ use crate::event::DEFAULT_TIMESTAMP_KEY; use crate::utils::arrow::get_field; -use anyhow::{anyhow, Error as AnyError}; use serde::{Deserialize, Serialize}; use std::str; use arrow_schema::{DataType, Field, Schema, TimeUnit}; -use std::{collections::HashMap, sync::Arc}; +use std::{ + collections::{HashMap, HashSet}, + sync::Arc, +}; + #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] pub struct StaticSchema { fields: Vec, @@ -54,13 +57,12 @@ pub struct Fields { } #[derive(Default, Debug, Clone, PartialEq, Serialize, Deserialize)] - pub struct Metadata {} pub fn convert_static_schema_to_arrow_schema( static_schema: StaticSchema, time_partition: &str, custom_partition: Option<&String>, -) -> Result, AnyError> { +) -> Result, StaticSchemaError> { let mut parsed_schema = ParsedSchema { fields: Vec::new(), metadata: HashMap::new(), @@ -83,11 +85,17 @@ pub fn convert_static_schema_to_arrow_schema( for partition in &custom_partition_list { if !custom_partition_exists.contains_key(*partition) { - return Err(anyhow!("custom partition field {partition} does not exist in the schema for the static schema logstream")); + return Err(StaticSchemaError::MissingCustomPartition( + partition.to_string(), + )); } } } + + let mut existing_field_names: HashSet = HashSet::new(); + for mut field in static_schema.fields { + validate_field_names(&field.name, &mut existing_field_names)?; if !time_partition.is_empty() && field.name == time_partition { time_partition_exists = true; field.data_type = "datetime".to_string(); @@ -127,18 +135,16 @@ pub fn convert_static_schema_to_arrow_schema( parsed_schema.fields.push(parsed_field); } if !time_partition.is_empty() && !time_partition_exists { - return Err(anyhow! { - format!( - "time partition field {time_partition} does not exist in the schema for the static schema logstream" - ), - }); + return Err(StaticSchemaError::MissingTimePartition( + time_partition.to_string(), + )); } add_parseable_fields_to_static_schema(parsed_schema) } fn add_parseable_fields_to_static_schema( parsed_schema: ParsedSchema, -) -> Result, AnyError> { +) -> Result, StaticSchemaError> { let mut schema: Vec> = Vec::new(); for field in parsed_schema.fields.iter() { let field = Field::new(field.name.clone(), field.data_type.clone(), field.nullable); @@ -146,10 +152,7 @@ fn add_parseable_fields_to_static_schema( } if get_field(&schema, DEFAULT_TIMESTAMP_KEY).is_some() { - return Err(anyhow!( - "field {} is a reserved field", - DEFAULT_TIMESTAMP_KEY - )); + return Err(StaticSchemaError::ReservedKey(DEFAULT_TIMESTAMP_KEY)); }; // add the p_timestamp field to the event schema to the 0th index @@ -176,3 +179,57 @@ fn default_dict_id() -> i64 { fn default_dict_is_ordered() -> bool { false } + +fn validate_field_names( + field_name: &str, + existing_fields: &mut HashSet, +) -> Result<(), StaticSchemaError> { + if field_name.is_empty() { + return Err(StaticSchemaError::EmptyFieldName); + } + + if !existing_fields.insert(field_name.to_string()) { + return Err(StaticSchemaError::DuplicateField(field_name.to_string())); + } + + Ok(()) +} + +#[derive(Debug, thiserror::Error)] +pub enum StaticSchemaError { + #[error( + "custom partition field {0} does not exist in the schema for the static schema logstream" + )] + MissingCustomPartition(String), + + #[error( + "time partition field {0} does not exist in the schema for the static schema logstream" + )] + MissingTimePartition(String), + + #[error("field {0:?} is a reserved field")] + ReservedKey(&'static str), + + #[error("field name cannot be empty")] + EmptyFieldName, + + #[error("duplicate field name: {0}")] + DuplicateField(String), +} + +#[cfg(test)] +mod tests { + use super::*; + #[test] + fn empty_field_names() { + let mut existing_field_names: HashSet = HashSet::new(); + assert!(validate_field_names("", &mut existing_field_names).is_err()); + } + + #[test] + fn duplicate_field_names() { + let mut existing_field_names: HashSet = HashSet::new(); + let _ = validate_field_names("test_field", &mut existing_field_names); + assert!(validate_field_names("test_field", &mut existing_field_names).is_err()); + } +}