diff --git a/Cargo.toml b/Cargo.toml index dccc6bdf19..c6a9e15e45 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -31,6 +31,7 @@ rust-version = "1.75.0" [workspace.dependencies] anyhow = "1.0.72" apache-avro = "0.16" +array-init = "2" arrow-arith = { version = ">=46" } arrow-array = { version = ">=46" } arrow-schema = { version = ">=46" } @@ -41,6 +42,7 @@ chrono = "0.4" derive_builder = "0.20.0" either = "1" env_logger = "0.11.0" +fnv = "1" futures = "0.3" iceberg = { version = "0.2.0", path = "./crates/iceberg" } iceberg-catalog-rest = { version = "0.2.0", path = "./crates/catalog/rest" } diff --git a/crates/iceberg/Cargo.toml b/crates/iceberg/Cargo.toml index 32288ee815..eb60412bb1 100644 --- a/crates/iceberg/Cargo.toml +++ b/crates/iceberg/Cargo.toml @@ -31,6 +31,7 @@ keywords = ["iceberg"] [dependencies] anyhow = { workspace = true } apache-avro = { workspace = true } +array-init = { workspace = true } arrow-arith = { workspace = true } arrow-array = { workspace = true } arrow-schema = { workspace = true } @@ -40,6 +41,7 @@ bitvec = { workspace = true } chrono = { workspace = true } derive_builder = { workspace = true } either = { workspace = true } +fnv = { workspace = true } futures = { workspace = true } itertools = { workspace = true } lazy_static = { workspace = true } diff --git a/crates/iceberg/src/expr/mod.rs b/crates/iceberg/src/expr/mod.rs index c08c836c32..567cf7e913 100644 --- a/crates/iceberg/src/expr/mod.rs +++ b/crates/iceberg/src/expr/mod.rs @@ -23,6 +23,8 @@ use std::fmt::{Display, Formatter}; pub use term::*; mod predicate; + +use crate::spec::SchemaRef; pub use predicate::*; /// Predicate operators used in expressions. @@ -147,6 +149,14 @@ impl PredicateOperator { } } +/// Bind expression to a schema. +pub trait Bind { + /// The type of the bound result. + type Bound; + /// Bind an expression to a schema. + fn bind(self, schema: SchemaRef, case_sensitive: bool) -> crate::Result; +} + #[cfg(test)] mod tests { use crate::expr::PredicateOperator; diff --git a/crates/iceberg/src/expr/predicate.rs b/crates/iceberg/src/expr/predicate.rs index 66a395624b..4ab9aae30f 100644 --- a/crates/iceberg/src/expr/predicate.rs +++ b/crates/iceberg/src/expr/predicate.rs @@ -19,13 +19,18 @@ //! Predicate expressions are used to filter data, and evaluates to a boolean value. For example, //! `a > 10` is a predicate expression, and it evaluates to `true` if `a` is greater than `10`, -use crate::expr::{BoundReference, PredicateOperator, Reference}; -use crate::spec::Datum; -use itertools::Itertools; -use std::collections::HashSet; use std::fmt::{Debug, Display, Formatter}; use std::ops::Not; +use array_init::array_init; +use fnv::FnvHashSet; +use itertools::Itertools; + +use crate::error::Result; +use crate::expr::{Bind, BoundReference, PredicateOperator, Reference}; +use crate::spec::{Datum, SchemaRef}; +use crate::{Error, ErrorKind}; + /// Logical expression, such as `AND`, `OR`, `NOT`. #[derive(PartialEq)] pub struct LogicalExpression { @@ -55,6 +60,24 @@ impl LogicalExpression { } } +impl Bind for LogicalExpression +where + T::Bound: Sized, +{ + type Bound = LogicalExpression; + + fn bind(self, schema: SchemaRef, case_sensitive: bool) -> Result { + let mut outputs: [Option>; N] = array_init(|_| None); + for (i, input) in self.inputs.into_iter().enumerate() { + outputs[i] = Some(Box::new(input.bind(schema.clone(), case_sensitive)?)); + } + + // It's safe to use `unwrap` here since they are all `Some`. + let bound_inputs = array_init::from_iter(outputs.into_iter().map(Option::unwrap)).unwrap(); + Ok(LogicalExpression::new(bound_inputs)) + } +} + /// Unary predicate, for example, `a IS NULL`. #[derive(PartialEq)] pub struct UnaryExpression { @@ -79,6 +102,15 @@ impl Display for UnaryExpression { } } +impl Bind for UnaryExpression { + type Bound = UnaryExpression; + + fn bind(self, schema: SchemaRef, case_sensitive: bool) -> Result { + let bound_term = self.term.bind(schema, case_sensitive)?; + Ok(UnaryExpression::new(self.op, bound_term)) + } +} + impl UnaryExpression { pub(crate) fn new(op: PredicateOperator, term: T) -> Self { debug_assert!(op.is_unary()); @@ -120,6 +152,15 @@ impl Display for BinaryExpression { } } +impl Bind for BinaryExpression { + type Bound = BinaryExpression; + + fn bind(self, schema: SchemaRef, case_sensitive: bool) -> Result { + let bound_term = self.term.bind(schema.clone(), case_sensitive)?; + Ok(BinaryExpression::new(self.op, bound_term, self.literal)) + } +} + /// Set predicates, for example, `a in (1, 2, 3)`. #[derive(PartialEq)] pub struct SetExpression { @@ -128,7 +169,7 @@ pub struct SetExpression { /// Term of this predicate, for example, `a` in `a in (1, 2, 3)`. term: T, /// Literals of this predicate, for example, `(1, 2, 3)` in `a in (1, 2, 3)`. - literals: HashSet, + literals: FnvHashSet, } impl Debug for SetExpression { @@ -141,12 +182,22 @@ impl Debug for SetExpression { } } -impl SetExpression { - pub(crate) fn new(op: PredicateOperator, term: T, literals: HashSet) -> Self { +impl SetExpression { + pub(crate) fn new(op: PredicateOperator, term: T, literals: FnvHashSet) -> Self { + debug_assert!(op.is_set()); Self { op, term, literals } } } +impl Bind for SetExpression { + type Bound = SetExpression; + + fn bind(self, schema: SchemaRef, case_sensitive: bool) -> Result { + let bound_term = self.term.bind(schema.clone(), case_sensitive)?; + Ok(SetExpression::new(self.op, bound_term, self.literals)) + } +} + impl Display for SetExpression { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { let mut literal_strs = self.literals.iter().map(|l| format!("{}", l)); @@ -172,6 +223,146 @@ pub enum Predicate { Set(SetExpression), } +impl Bind for Predicate { + type Bound = BoundPredicate; + + fn bind(self, schema: SchemaRef, case_sensitive: bool) -> Result { + match self { + Predicate::And(expr) => { + let bound_expr = expr.bind(schema, case_sensitive)?; + + let [left, right] = bound_expr.inputs; + Ok(match (left, right) { + (_, r) if matches!(&*r, &BoundPredicate::AlwaysFalse) => { + BoundPredicate::AlwaysFalse + } + (l, _) if matches!(&*l, &BoundPredicate::AlwaysFalse) => { + BoundPredicate::AlwaysFalse + } + (left, r) if matches!(&*r, &BoundPredicate::AlwaysTrue) => *left, + (l, right) if matches!(&*l, &BoundPredicate::AlwaysTrue) => *right, + (left, right) => BoundPredicate::And(LogicalExpression::new([left, right])), + }) + } + Predicate::Not(expr) => { + let bound_expr = expr.bind(schema, case_sensitive)?; + let [inner] = bound_expr.inputs; + Ok(match inner { + e if matches!(&*e, &BoundPredicate::AlwaysTrue) => BoundPredicate::AlwaysFalse, + e if matches!(&*e, &BoundPredicate::AlwaysFalse) => BoundPredicate::AlwaysTrue, + e => BoundPredicate::Not(LogicalExpression::new([e])), + }) + } + Predicate::Or(expr) => { + let bound_expr = expr.bind(schema, case_sensitive)?; + let [left, right] = bound_expr.inputs; + Ok(match (left, right) { + (l, r) + if matches!(&*r, &BoundPredicate::AlwaysTrue) + || matches!(&*l, &BoundPredicate::AlwaysTrue) => + { + BoundPredicate::AlwaysTrue + } + (left, r) if matches!(&*r, &BoundPredicate::AlwaysFalse) => *left, + (l, right) if matches!(&*l, &BoundPredicate::AlwaysFalse) => *right, + (left, right) => BoundPredicate::Or(LogicalExpression::new([left, right])), + }) + } + Predicate::Unary(expr) => { + let bound_expr = expr.bind(schema, case_sensitive)?; + + match &bound_expr.op { + &PredicateOperator::IsNull => { + if bound_expr.term.field().required { + return Ok(BoundPredicate::AlwaysFalse); + } + } + &PredicateOperator::NotNull => { + if bound_expr.term.field().required { + return Ok(BoundPredicate::AlwaysTrue); + } + } + &PredicateOperator::IsNan | &PredicateOperator::NotNan => { + if !bound_expr.term.field().field_type.is_floating_type() { + return Err(Error::new( + ErrorKind::DataInvalid, + format!( + "Expecting floating point type, but found {}", + bound_expr.term.field().field_type + ), + )); + } + } + op => { + return Err(Error::new( + ErrorKind::Unexpected, + format!("Expecting unary operator, but found {op}"), + )) + } + } + + Ok(BoundPredicate::Unary(bound_expr)) + } + Predicate::Binary(expr) => { + let bound_expr = expr.bind(schema, case_sensitive)?; + let bound_literal = bound_expr.literal.to(&bound_expr.term.field().field_type)?; + Ok(BoundPredicate::Binary(BinaryExpression::new( + bound_expr.op, + bound_expr.term, + bound_literal, + ))) + } + Predicate::Set(expr) => { + let bound_expr = expr.bind(schema, case_sensitive)?; + let bound_literals = bound_expr + .literals + .into_iter() + .map(|l| l.to(&bound_expr.term.field().field_type)) + .collect::>>()?; + + match &bound_expr.op { + &PredicateOperator::In => { + if bound_literals.is_empty() { + return Ok(BoundPredicate::AlwaysFalse); + } + if bound_literals.len() == 1 { + return Ok(BoundPredicate::Binary(BinaryExpression::new( + PredicateOperator::Eq, + bound_expr.term, + bound_literals.into_iter().next().unwrap(), + ))); + } + } + &PredicateOperator::NotIn => { + if bound_literals.is_empty() { + return Ok(BoundPredicate::AlwaysTrue); + } + if bound_literals.len() == 1 { + return Ok(BoundPredicate::Binary(BinaryExpression::new( + PredicateOperator::NotEq, + bound_expr.term, + bound_literals.into_iter().next().unwrap(), + ))); + } + } + op => { + return Err(Error::new( + ErrorKind::Unexpected, + format!("Expecting unary operator,but found {op}"), + )) + } + } + + Ok(BoundPredicate::Set(SetExpression::new( + bound_expr.op, + bound_expr.term, + bound_literals, + ))) + } + } + } +} + impl Display for Predicate { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { match self { @@ -292,7 +483,11 @@ impl Predicate { impl Not for Predicate { type Output = Predicate; - /// Create a predicate which is the reverse of this predicate. For example: `NOT (a > 10)` + /// Create a predicate which is the reverse of this predicate. For example: `NOT (a > 10)`. + /// + /// This is different from [`Predicate::negate()`] since it doesn't rewrite expression, but + /// just adds a `NOT` operator. + /// /// # Example /// ///```rust @@ -332,12 +527,46 @@ pub enum BoundPredicate { Set(SetExpression), } +impl Display for BoundPredicate { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + match self { + BoundPredicate::AlwaysTrue => { + write!(f, "True") + } + BoundPredicate::AlwaysFalse => { + write!(f, "False") + } + BoundPredicate::And(expr) => { + write!(f, "({}) AND ({})", expr.inputs()[0], expr.inputs()[1]) + } + BoundPredicate::Or(expr) => { + write!(f, "({}) OR ({})", expr.inputs()[0], expr.inputs()[1]) + } + BoundPredicate::Not(expr) => { + write!(f, "NOT ({})", expr.inputs()[0]) + } + BoundPredicate::Unary(expr) => { + write!(f, "{}", expr) + } + BoundPredicate::Binary(expr) => { + write!(f, "{}", expr) + } + BoundPredicate::Set(expr) => { + write!(f, "{}", expr) + } + } + } +} + #[cfg(test)] mod tests { + use std::ops::Not; + use std::sync::Arc; + + use crate::expr::Bind; use crate::expr::Reference; use crate::spec::Datum; - use std::collections::HashSet; - use std::ops::Not; + use crate::spec::{NestedField, PrimitiveType, Schema, SchemaRef, Type}; #[test] fn test_predicate_negate_and() { @@ -406,13 +635,239 @@ mod tests { #[test] fn test_predicate_negate_set() { - let expression = Reference::new("a").is_in(HashSet::from([Datum::long(5), Datum::long(6)])); + let expression = Reference::new("a").is_in([Datum::long(5), Datum::long(6)]); - let expected = - Reference::new("a").is_not_in(HashSet::from([Datum::long(5), Datum::long(6)])); + let expected = Reference::new("a").is_not_in([Datum::long(5), Datum::long(6)]); let result = expression.negate(); assert_eq!(result, expected); } + + fn table_schema_simple() -> SchemaRef { + Arc::new( + Schema::builder() + .with_schema_id(1) + .with_identifier_field_ids(vec![2]) + .with_fields(vec![ + NestedField::optional(1, "foo", Type::Primitive(PrimitiveType::String)).into(), + NestedField::required(2, "bar", Type::Primitive(PrimitiveType::Int)).into(), + NestedField::optional(3, "baz", Type::Primitive(PrimitiveType::Boolean)).into(), + ]) + .build() + .unwrap(), + ) + } + + #[test] + fn test_bind_is_null() { + let schema = table_schema_simple(); + let expr = Reference::new("foo").is_null(); + let bound_expr = expr.bind(schema, true).unwrap(); + assert_eq!(&format!("{bound_expr}"), "foo IS NULL"); + } + + #[test] + fn test_bind_is_null_required() { + let schema = table_schema_simple(); + let expr = Reference::new("bar").is_null(); + let bound_expr = expr.bind(schema, true).unwrap(); + assert_eq!(&format!("{bound_expr}"), "False"); + } + + #[test] + fn test_bind_is_not_null() { + let schema = table_schema_simple(); + let expr = Reference::new("foo").is_not_null(); + let bound_expr = expr.bind(schema, true).unwrap(); + assert_eq!(&format!("{bound_expr}"), "foo IS NOT NULL"); + } + + #[test] + fn test_bind_is_not_null_required() { + let schema = table_schema_simple(); + let expr = Reference::new("bar").is_not_null(); + let bound_expr = expr.bind(schema, true).unwrap(); + assert_eq!(&format!("{bound_expr}"), "True"); + } + + #[test] + fn test_bind_less_than() { + let schema = table_schema_simple(); + let expr = Reference::new("bar").less_than(Datum::int(10)); + let bound_expr = expr.bind(schema, true).unwrap(); + assert_eq!(&format!("{bound_expr}"), "bar < 10"); + } + + #[test] + fn test_bind_less_than_wrong_type() { + let schema = table_schema_simple(); + let expr = Reference::new("bar").less_than(Datum::string("abcd")); + let bound_expr = expr.bind(schema, true); + assert!(bound_expr.is_err()); + } + + #[test] + fn test_bind_greater_than_or_eq() { + let schema = table_schema_simple(); + let expr = Reference::new("bar").greater_than_or_equal_to(Datum::int(10)); + let bound_expr = expr.bind(schema, true).unwrap(); + assert_eq!(&format!("{bound_expr}"), "bar >= 10"); + } + + #[test] + fn test_bind_greater_than_or_eq_wrong_type() { + let schema = table_schema_simple(); + let expr = Reference::new("bar").greater_than_or_equal_to(Datum::string("abcd")); + let bound_expr = expr.bind(schema, true); + assert!(bound_expr.is_err()); + } + + #[test] + fn test_bind_in() { + let schema = table_schema_simple(); + let expr = Reference::new("bar").is_in([Datum::int(10), Datum::int(20)]); + let bound_expr = expr.bind(schema, true).unwrap(); + assert_eq!(&format!("{bound_expr}"), "bar IN (20, 10)"); + } + + #[test] + fn test_bind_in_empty() { + let schema = table_schema_simple(); + let expr = Reference::new("bar").is_in(vec![]); + let bound_expr = expr.bind(schema, true).unwrap(); + assert_eq!(&format!("{bound_expr}"), "False"); + } + + #[test] + fn test_bind_in_one_literal() { + let schema = table_schema_simple(); + let expr = Reference::new("bar").is_in(vec![Datum::int(10)]); + let bound_expr = expr.bind(schema, true).unwrap(); + assert_eq!(&format!("{bound_expr}"), "bar = 10"); + } + + #[test] + fn test_bind_in_wrong_type() { + let schema = table_schema_simple(); + let expr = Reference::new("bar").is_in(vec![Datum::int(10), Datum::string("abcd")]); + let bound_expr = expr.bind(schema, true); + assert!(bound_expr.is_err()); + } + + #[test] + fn test_bind_not_in() { + let schema = table_schema_simple(); + let expr = Reference::new("bar").is_not_in([Datum::int(10), Datum::int(20)]); + let bound_expr = expr.bind(schema, true).unwrap(); + assert_eq!(&format!("{bound_expr}"), "bar NOT IN (20, 10)"); + } + + #[test] + fn test_bind_not_in_empty() { + let schema = table_schema_simple(); + let expr = Reference::new("bar").is_not_in(vec![]); + let bound_expr = expr.bind(schema, true).unwrap(); + assert_eq!(&format!("{bound_expr}"), "True"); + } + + #[test] + fn test_bind_not_in_one_literal() { + let schema = table_schema_simple(); + let expr = Reference::new("bar").is_not_in(vec![Datum::int(10)]); + let bound_expr = expr.bind(schema, true).unwrap(); + assert_eq!(&format!("{bound_expr}"), "bar != 10"); + } + + #[test] + fn test_bind_not_in_wrong_type() { + let schema = table_schema_simple(); + let expr = Reference::new("bar").is_not_in([Datum::int(10), Datum::string("abcd")]); + let bound_expr = expr.bind(schema, true); + assert!(bound_expr.is_err()); + } + + #[test] + fn test_bind_and() { + let schema = table_schema_simple(); + let expr = Reference::new("bar") + .less_than(Datum::int(10)) + .and(Reference::new("foo").is_null()); + let bound_expr = expr.bind(schema, true).unwrap(); + assert_eq!(&format!("{bound_expr}"), "(bar < 10) AND (foo IS NULL)"); + } + + #[test] + fn test_bind_and_always_false() { + let schema = table_schema_simple(); + let expr = Reference::new("foo") + .less_than(Datum::string("abcd")) + .and(Reference::new("bar").is_null()); + let bound_expr = expr.bind(schema, true).unwrap(); + assert_eq!(&format!("{bound_expr}"), "False"); + } + + #[test] + fn test_bind_and_always_true() { + let schema = table_schema_simple(); + let expr = Reference::new("foo") + .less_than(Datum::string("abcd")) + .and(Reference::new("bar").is_not_null()); + let bound_expr = expr.bind(schema, true).unwrap(); + assert_eq!(&format!("{bound_expr}"), r#"foo < "abcd""#); + } + + #[test] + fn test_bind_or() { + let schema = table_schema_simple(); + let expr = Reference::new("bar") + .less_than(Datum::int(10)) + .or(Reference::new("foo").is_null()); + let bound_expr = expr.bind(schema, true).unwrap(); + assert_eq!(&format!("{bound_expr}"), "(bar < 10) OR (foo IS NULL)"); + } + + #[test] + fn test_bind_or_always_true() { + let schema = table_schema_simple(); + let expr = Reference::new("foo") + .less_than(Datum::string("abcd")) + .or(Reference::new("bar").is_not_null()); + let bound_expr = expr.bind(schema, true).unwrap(); + assert_eq!(&format!("{bound_expr}"), "True"); + } + + #[test] + fn test_bind_or_always_false() { + let schema = table_schema_simple(); + let expr = Reference::new("foo") + .less_than(Datum::string("abcd")) + .or(Reference::new("bar").is_null()); + let bound_expr = expr.bind(schema, true).unwrap(); + assert_eq!(&format!("{bound_expr}"), r#"foo < "abcd""#); + } + + #[test] + fn test_bind_not() { + let schema = table_schema_simple(); + let expr = !Reference::new("bar").less_than(Datum::int(10)); + let bound_expr = expr.bind(schema, true).unwrap(); + assert_eq!(&format!("{bound_expr}"), "NOT (bar < 10)"); + } + + #[test] + fn test_bind_not_always_true() { + let schema = table_schema_simple(); + let expr = !Reference::new("bar").is_not_null(); + let bound_expr = expr.bind(schema, true).unwrap(); + assert_eq!(&format!("{bound_expr}"), "False"); + } + + #[test] + fn test_bind_not_always_false() { + let schema = table_schema_simple(); + let expr = !Reference::new("bar").is_null(); + let bound_expr = expr.bind(schema, true).unwrap(); + assert_eq!(&format!("{bound_expr}"), r#"True"#); + } } diff --git a/crates/iceberg/src/expr/term.rs b/crates/iceberg/src/expr/term.rs index 6be502fff1..e39c97e0f6 100644 --- a/crates/iceberg/src/expr/term.rs +++ b/crates/iceberg/src/expr/term.rs @@ -17,11 +17,15 @@ //! Term definition. -use crate::expr::{BinaryExpression, Predicate, PredicateOperator, SetExpression, UnaryExpression}; -use crate::spec::{Datum, NestedField, NestedFieldRef}; -use std::collections::HashSet; use std::fmt::{Display, Formatter}; +use fnv::FnvHashSet; + +use crate::expr::Bind; +use crate::expr::{BinaryExpression, Predicate, PredicateOperator, SetExpression, UnaryExpression}; +use crate::spec::{Datum, NestedField, NestedFieldRef, SchemaRef}; +use crate::{Error, ErrorKind}; + /// Unbound term before binding to a schema. pub type Term = Reference; @@ -123,16 +127,20 @@ impl Reference { /// /// ```rust /// - /// use std::collections::HashSet; + /// use fnv::FnvHashSet; /// use iceberg::expr::Reference; /// use iceberg::spec::Datum; - /// let expr = Reference::new("a").is_in(HashSet::from([Datum::long(5), Datum::long(6)])); + /// let expr = Reference::new("a").is_in([Datum::long(5), Datum::long(6)]); /// /// let as_string = format!("{expr}"); /// assert!(&as_string == "a IN (5, 6)" || &as_string == "a IN (6, 5)"); /// ``` - pub fn is_in(self, literals: HashSet) -> Predicate { - Predicate::Set(SetExpression::new(PredicateOperator::In, self, literals)) + pub fn is_in(self, literals: impl IntoIterator) -> Predicate { + Predicate::Set(SetExpression::new( + PredicateOperator::In, + self, + FnvHashSet::from_iter(literals), + )) } /// Creates an is-not-in expression. For example, `a IS NOT IN (5, 6)`. @@ -141,16 +149,20 @@ impl Reference { /// /// ```rust /// - /// use std::collections::HashSet; + /// use fnv::FnvHashSet; /// use iceberg::expr::Reference; /// use iceberg::spec::Datum; - /// let expr = Reference::new("a").is_not_in(HashSet::from([Datum::long(5), Datum::long(6)])); + /// let expr = Reference::new("a").is_not_in([Datum::long(5), Datum::long(6)]); /// /// let as_string = format!("{expr}"); /// assert!(&as_string == "a NOT IN (5, 6)" || &as_string == "a NOT IN (6, 5)"); /// ``` - pub fn is_not_in(self, literals: HashSet) -> Predicate { - Predicate::Set(SetExpression::new(PredicateOperator::NotIn, self, literals)) + pub fn is_not_in(self, literals: impl IntoIterator) -> Predicate { + Predicate::Set(SetExpression::new( + PredicateOperator::NotIn, + self, + FnvHashSet::from_iter(literals), + )) } } @@ -160,8 +172,28 @@ impl Display for Reference { } } +impl Bind for Reference { + type Bound = BoundReference; + + fn bind(self, schema: SchemaRef, case_sensitive: bool) -> crate::Result { + let field = if case_sensitive { + schema.field_by_name(&self.name) + } else { + schema.field_by_name_case_insensitive(&self.name) + }; + + let field = field.ok_or_else(|| { + Error::new( + ErrorKind::DataInvalid, + format!("Field {} not found in schema", self.name), + ) + })?; + Ok(BoundReference::new(self.name, field.clone())) + } +} + /// A named reference in a bound expression after binding to a schema. -#[derive(Debug, Clone)] +#[derive(Debug, Clone, PartialEq, Eq)] pub struct BoundReference { // This maybe different from [`name`] filed in [`NestedField`] since this contains full path. // For example, if the field is `a.b.c`, then `field.name` is `c`, but `original_name` is `a.b.c`. @@ -192,3 +224,67 @@ impl Display for BoundReference { /// Bound term after binding to a schema. pub type BoundTerm = BoundReference; + +#[cfg(test)] +mod tests { + use std::sync::Arc; + + use crate::expr::{Bind, BoundReference, Reference}; + use crate::spec::{NestedField, PrimitiveType, Schema, SchemaRef, Type}; + + fn table_schema_simple() -> SchemaRef { + Arc::new( + Schema::builder() + .with_schema_id(1) + .with_identifier_field_ids(vec![2]) + .with_fields(vec![ + NestedField::optional(1, "foo", Type::Primitive(PrimitiveType::String)).into(), + NestedField::required(2, "bar", Type::Primitive(PrimitiveType::Int)).into(), + NestedField::optional(3, "baz", Type::Primitive(PrimitiveType::Boolean)).into(), + ]) + .build() + .unwrap(), + ) + } + + #[test] + fn test_bind_reference() { + let schema = table_schema_simple(); + let reference = Reference::new("bar").bind(schema, true).unwrap(); + + let expected_ref = BoundReference::new( + "bar", + NestedField::required(2, "bar", Type::Primitive(PrimitiveType::Int)).into(), + ); + + assert_eq!(expected_ref, reference); + } + + #[test] + fn test_bind_reference_case_insensitive() { + let schema = table_schema_simple(); + let reference = Reference::new("BAR").bind(schema, false).unwrap(); + + let expected_ref = BoundReference::new( + "BAR", + NestedField::required(2, "bar", Type::Primitive(PrimitiveType::Int)).into(), + ); + + assert_eq!(expected_ref, reference); + } + + #[test] + fn test_bind_reference_failure() { + let schema = table_schema_simple(); + let result = Reference::new("bar_not_eix").bind(schema, true); + + assert!(result.is_err()); + } + + #[test] + fn test_bind_reference_case_insensitive_failure() { + let schema = table_schema_simple(); + let result = Reference::new("bar_non_exist").bind(schema, false); + assert!(result.is_err()); + } +} diff --git a/crates/iceberg/src/spec/datatypes.rs b/crates/iceberg/src/spec/datatypes.rs index 8f404e928a..6ea4175e51 100644 --- a/crates/iceberg/src/spec/datatypes.rs +++ b/crates/iceberg/src/spec/datatypes.rs @@ -135,6 +135,15 @@ impl Type { ensure_data_valid!(precision > 0 && precision <= MAX_DECIMAL_PRECISION, "Decimals with precision larger than {MAX_DECIMAL_PRECISION} are not supported: {precision}",); Ok(Type::Primitive(PrimitiveType::Decimal { precision, scale })) } + + /// Check if it's float or double type. + #[inline(always)] + pub fn is_floating_type(&self) -> bool { + matches!( + self, + Type::Primitive(PrimitiveType::Float) | Type::Primitive(PrimitiveType::Double) + ) + } } impl From for Type { diff --git a/crates/iceberg/src/spec/schema.rs b/crates/iceberg/src/spec/schema.rs index 34e383f651..975a2a9ef7 100644 --- a/crates/iceberg/src/spec/schema.rs +++ b/crates/iceberg/src/spec/schema.rs @@ -51,6 +51,7 @@ pub struct Schema { id_to_field: HashMap, name_to_id: HashMap, + lowercase_name_to_id: HashMap, id_to_name: HashMap, } @@ -117,6 +118,11 @@ impl SchemaBuilder { index.indexes() }; + let lowercase_name_to_id = name_to_id + .iter() + .map(|(k, v)| (k.to_lowercase(), *v)) + .collect(); + Ok(Schema { r#struct, schema_id: self.schema_id, @@ -127,6 +133,7 @@ impl SchemaBuilder { id_to_field, name_to_id, + lowercase_name_to_id, id_to_name, }) } @@ -212,6 +219,15 @@ impl Schema { .and_then(|id| self.field_by_id(*id)) } + /// Get field by field name, but in case-insensitive way. + /// + /// Both full name and short name could work here. + pub fn field_by_name_case_insensitive(&self, field_name: &str) -> Option<&NestedFieldRef> { + self.lowercase_name_to_id + .get(&field_name.to_lowercase()) + .and_then(|id| self.field_by_id(*id)) + } + /// Get field by alias. pub fn field_by_alias(&self, alias: &str) -> Option<&NestedFieldRef> { self.alias_to_id @@ -1032,6 +1048,42 @@ table { assert_eq!(&expected_name_to_id, &schema.name_to_id); } + #[test] + fn test_schema_index_by_name_case_insensitive() { + let expected_name_to_id = HashMap::from( + [ + ("fOo", 1), + ("Bar", 2), + ("BAz", 3), + ("quX", 4), + ("quX.ELEment", 5), + ("qUUx", 6), + ("QUUX.KEY", 7), + ("QUUX.Value", 8), + ("qUUX.VALUE.Key", 9), + ("qUux.VaLue.Value", 10), + ("lOCAtION", 11), + ("LOCAtioN.ELeMENt", 12), + ("LoCATion.element.LATitude", 13), + ("locatION.ElemeNT.LONgitude", 14), + ("LOCAtiON.LATITUDE", 13), + ("LOCATION.LONGITUDE", 14), + ("PERSon", 15), + ("PERSON.Name", 16), + ("peRSON.AGe", 17), + ] + .map(|e| (e.0.to_string(), e.1)), + ); + + let schema = table_schema_nested(); + for (name, id) in expected_name_to_id { + assert_eq!( + Some(id), + schema.field_by_name_case_insensitive(&name).map(|f| f.id) + ); + } + } + #[test] fn test_schema_find_column_name() { let expected_column_name = HashMap::from([ diff --git a/crates/iceberg/src/spec/values.rs b/crates/iceberg/src/spec/values.rs index 113620f1c3..00f2e57d2b 100644 --- a/crates/iceberg/src/spec/values.rs +++ b/crates/iceberg/src/spec/values.rs @@ -106,7 +106,7 @@ impl Display for Datum { (_, PrimitiveLiteral::TimestampTZ(val)) => { write!(f, "{}", microseconds_to_datetimetz(*val)) } - (_, PrimitiveLiteral::String(val)) => write!(f, "{}", val), + (_, PrimitiveLiteral::String(val)) => write!(f, r#""{}""#, val), (_, PrimitiveLiteral::UUID(val)) => write!(f, "{}", val), (_, PrimitiveLiteral::Fixed(val)) => display_bytes(val, f), (_, PrimitiveLiteral::Binary(val)) => display_bytes(val, f), @@ -529,7 +529,7 @@ impl Datum { /// use iceberg::spec::Datum; /// let t = Datum::string("ss"); /// - /// assert_eq!(&format!("{t}"), "ss"); + /// assert_eq!(&format!("{t}"), r#""ss""#); /// ``` pub fn string(s: S) -> Self { Self { @@ -658,6 +658,21 @@ impl Datum { unreachable!("Decimal type must be primitive.") } } + + /// Convert the datum to `target_type`. + pub fn to(self, target_type: &Type) -> Result { + // TODO: We should allow more type conversions + match target_type { + Type::Primitive(typ) if typ == &self.r#type => Ok(self), + _ => Err(Error::new( + ErrorKind::DataInvalid, + format!( + "Can't convert datum from {} type to {} type.", + self.r#type, target_type + ), + )), + } + } } /// Values present in iceberg type