diff --git a/crates/iceberg/src/expr/predicate.rs b/crates/iceberg/src/expr/predicate.rs index 67a46e2b11..bd325a54c4 100644 --- a/crates/iceberg/src/expr/predicate.rs +++ b/crates/iceberg/src/expr/predicate.rs @@ -116,6 +116,10 @@ impl UnaryExpression { debug_assert!(op.is_unary()); Self { op, term } } + + pub(crate) fn op(&self) -> PredicateOperator { + self.op + } } /// Binary predicate, for example, `a > 10`. @@ -144,6 +148,14 @@ impl BinaryExpression { debug_assert!(op.is_binary()); Self { op, term, literal } } + + pub(crate) fn op(&self) -> PredicateOperator { + self.op + } + + pub(crate) fn literal(&self) -> &Datum { + &self.literal + } } impl Display for BinaryExpression { @@ -187,6 +199,14 @@ impl SetExpression { debug_assert!(op.is_set()); Self { op, term, literals } } + + pub(crate) fn op(&self) -> PredicateOperator { + self.op + } + + pub(crate) fn literals(&self) -> &FnvHashSet { + &self.literals + } } impl Bind for SetExpression { diff --git a/crates/iceberg/src/spec/transform.rs b/crates/iceberg/src/spec/transform.rs index 839d582dc0..ec5d281b9a 100644 --- a/crates/iceberg/src/spec/transform.rs +++ b/crates/iceberg/src/spec/transform.rs @@ -18,12 +18,20 @@ //! Transforms in iceberg. use crate::error::{Error, Result}; +use crate::expr::{ + BinaryExpression, BoundPredicate, Predicate, PredicateOperator, Reference, SetExpression, + UnaryExpression, +}; use crate::spec::datatypes::{PrimitiveType, Type}; +use crate::transform::create_transform_function; use crate::ErrorKind; +use fnv::FnvHashSet; use serde::{Deserialize, Deserializer, Serialize, Serializer}; use std::fmt::{Display, Formatter}; use std::str::FromStr; +use super::Datum; + /// Transform is used to transform predicates to partition predicates, /// in addition to transforming data values. /// @@ -261,6 +269,60 @@ impl Transform { _ => self == other, } } + /// Projects predicate to `Transform` + pub fn project(&self, name: String, pred: &BoundPredicate) -> Result> { + let func = create_transform_function(self)?; + + // TODO: Support other transforms + let projection = match self { + Transform::Bucket(_) => match pred { + BoundPredicate::Unary(expr) => Some(Predicate::Unary(UnaryExpression::new( + expr.op(), + Reference::new(name), + ))), + BoundPredicate::Binary(expr) => { + if expr.op() != PredicateOperator::Eq { + return Ok(None); + } + + let result = func.transform(expr.literal().to_arrow_array())?; + + Some(Predicate::Binary(BinaryExpression::new( + expr.op(), + Reference::new(name), + Datum::from_arrow_array(&result)?, + ))) + } + BoundPredicate::Set(expr) => { + if expr.op() != PredicateOperator::In { + return Ok(None); + } + + let projected_set: Result> = expr + .literals() + .iter() + .map(|lit| { + func.transform(lit.to_arrow_array()) + .and_then(|arr| Datum::from_arrow_array(&arr)) + }) + .collect(); + + match projected_set { + Err(e) => return Err(e), + Ok(set) => Some(Predicate::Set(SetExpression::new( + expr.op(), + Reference::new(name), + set, + ))), + } + } + _ => None, + }, + _ => todo!(), + }; + + Ok(projection) + } } impl Display for Transform { @@ -358,6 +420,14 @@ impl<'de> Deserialize<'de> for Transform { #[cfg(test)] mod tests { + use std::sync::Arc; + + use fnv::FnvHashSet; + + use crate::expr::{ + BinaryExpression, BoundPredicate, BoundReference, Predicate, PredicateOperator, Reference, + SetExpression, UnaryExpression, + }; use crate::spec::datatypes::PrimitiveType::{ Binary, Date, Decimal, Fixed, Int, Long, String as StringType, Time, Timestamp, Timestamptz, Uuid, @@ -365,6 +435,8 @@ mod tests { use crate::spec::datatypes::Type::{Primitive, Struct}; use crate::spec::datatypes::{NestedField, StructType, Type}; use crate::spec::transform::Transform; + use crate::spec::{Datum, PrimitiveType}; + use crate::Result; struct TestParameter { display: String, @@ -398,6 +470,76 @@ mod tests { } } + #[test] + fn test_bucket_project_set() -> Result<()> { + let trans = Transform::Bucket(8); + let name = "projected_name".to_string(); + + let field = NestedField::required(1, "a", Type::Primitive(PrimitiveType::Int)); + + let pred = BoundPredicate::Set(SetExpression::new( + PredicateOperator::In, + BoundReference::new("original_name", Arc::new(field)), + FnvHashSet::from_iter([Datum::int(5), Datum::int(6)]), + )); + + let expected = Some(Predicate::Set(SetExpression::new( + PredicateOperator::In, + Reference::new(&name), + FnvHashSet::from_iter([Datum::int(7), Datum::int(1)]), + ))); + + let result = trans.project(name, &pred)?; + + assert_eq!(result, expected); + + Ok(()) + } + + #[test] + fn test_bucket_project_binary() -> Result<()> { + let trans = Transform::Bucket(8); + let name = "projected_name".to_string(); + + let field = NestedField::required(1, "a", Type::Primitive(PrimitiveType::Int)); + + let pred = BoundPredicate::Binary(BinaryExpression::new( + PredicateOperator::Eq, + BoundReference::new("original_name", Arc::new(field)), + Datum::int(5), + )); + + let expected = Some(Predicate::Binary(BinaryExpression::new( + PredicateOperator::Eq, + Reference::new(&name), + Datum::int(7), + ))); + + let result = trans.project(name, &pred)?; + + assert_eq!(result, expected); + + Ok(()) + } + + #[test] + fn test_bucket_project_unary() { + let trans = Transform::Bucket(8); + + let name = "projected_name".to_string(); + + let field = NestedField::required(1, "a", Type::Primitive(PrimitiveType::Int)); + + let pred = BoundPredicate::Unary(UnaryExpression::new( + PredicateOperator::IsNull, + BoundReference::new("original_name", Arc::new(field)), + )); + + let result = trans.project(name, &pred).unwrap().unwrap(); + + assert_eq!(format!("{result}"), "projected_name IS NULL"); + } + #[test] fn test_bucket_transform() { let trans = Transform::Bucket(8); diff --git a/crates/iceberg/src/spec/values.rs b/crates/iceberg/src/spec/values.rs index 00f2e57d2b..3147a77a78 100644 --- a/crates/iceberg/src/spec/values.rs +++ b/crates/iceberg/src/spec/values.rs @@ -21,8 +21,11 @@ use std::fmt::{Display, Formatter}; use std::str::FromStr; +use std::sync::Arc; use std::{any::Any, collections::BTreeMap}; +use arrow_array::{ArrayRef, Int32Array, Int64Array}; +use arrow_schema::DataType; use bitvec::vec::BitVec; use chrono::{DateTime, NaiveDate, NaiveDateTime, NaiveTime, TimeZone, Utc}; use ordered_float::OrderedFloat; @@ -141,6 +144,54 @@ impl From for Literal { } impl Datum { + /// Convert `Datum` into `arrow_array::ArrayRef` + pub fn to_arrow_array(&self) -> ArrayRef { + // TODO: Support more PrimitiveLiterals + match self.literal { + PrimitiveLiteral::Int(v) => Arc::new(Int32Array::from_value(v, 1)), + PrimitiveLiteral::Long(v) => Arc::new(Int64Array::from_value(v, 1)), + _ => todo!(), + } + } + /// Creates `Datum` from `arrow_array::ArrayRef` + pub fn from_arrow_array(input: &ArrayRef) -> Result { + if input.is_empty() { + return Err(Error::new( + ErrorKind::DataInvalid, + "Input array must not be empty", + )); + } + + let downcast_err = || Error::new(ErrorKind::Unexpected, "Failed to downcast"); + + // TODO: Support more data_types + match input.data_type() { + DataType::Int32 => { + let arr = input + .as_any() + .downcast_ref::() + .ok_or_else(downcast_err)?; + + Ok(Self { + r#type: PrimitiveType::Int, + literal: PrimitiveLiteral::Int(arr.value(0)), + }) + } + DataType::Int64 => { + let arr = input + .as_any() + .downcast_ref::() + .ok_or_else(downcast_err)?; + + Ok(Self { + r#type: PrimitiveType::Long, + literal: PrimitiveLiteral::Long(arr.value(0)), + }) + } + _ => todo!(), + } + } + /// Creates a boolean value. /// /// Example: