Skip to content

Commit 298946e

Browse files
committed
feat: Implement binding expression
1 parent f61d475 commit 298946e

File tree

8 files changed

+439
-4
lines changed

8 files changed

+439
-4
lines changed

Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -41,6 +41,7 @@ chrono = "0.4"
4141
derive_builder = "0.20.0"
4242
either = "1"
4343
env_logger = "0.11.0"
44+
fnv = "1"
4445
futures = "0.3"
4546
iceberg = { version = "0.2.0", path = "./crates/iceberg" }
4647
iceberg-catalog-rest = { version = "0.2.0", path = "./crates/catalog/rest" }

crates/iceberg/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -40,6 +40,7 @@ bitvec = { workspace = true }
4040
chrono = { workspace = true }
4141
derive_builder = { workspace = true }
4242
either = { workspace = true }
43+
fnv = { workspace = true }
4344
futures = { workspace = true }
4445
itertools = { workspace = true }
4546
lazy_static = { workspace = true }

crates/iceberg/src/expr/mod.rs

Lines changed: 10 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -23,6 +23,8 @@ use std::fmt::{Display, Formatter};
2323

2424
pub use term::*;
2525
mod predicate;
26+
27+
use crate::spec::SchemaRef;
2628
pub use predicate::*;
2729

2830
/// Predicate operators used in expressions.
@@ -147,6 +149,14 @@ impl PredicateOperator {
147149
}
148150
}
149151

152+
/// Bind expression to a schema.
153+
pub trait Bind {
154+
/// The type of the bound result.
155+
type Bound;
156+
/// Bind an expression to a schema.
157+
fn bind(self, schema: SchemaRef, case_sensitive: bool) -> crate::Result<Self::Bound>;
158+
}
159+
150160
#[cfg(test)]
151161
mod tests {
152162
use crate::expr::PredicateOperator;

crates/iceberg/src/expr/predicate.rs

Lines changed: 228 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -23,7 +23,14 @@ use crate::expr::{BoundReference, PredicateOperator, Reference};
2323
use crate::spec::Datum;
2424
use itertools::Itertools;
2525
use std::collections::HashSet;
26+
use crate::error::Result;
27+
use crate::expr::{Bind, BoundReference, PredicateOperator, Reference};
28+
use crate::spec::{Datum, SchemaRef};
29+
use crate::{Error, ErrorKind};
30+
use fnv::FnvHashSet;
31+
2632
use std::fmt::{Debug, Display, Formatter};
33+
use std::mem::MaybeUninit;
2734
use std::ops::Not;
2835

2936
/// Logical expression, such as `AND`, `OR`, `NOT`.
@@ -55,6 +62,29 @@ impl<T, const N: usize> LogicalExpression<T, N> {
5562
}
5663
}
5764

65+
impl<T: Bind, const N: usize> Bind for LogicalExpression<T, N> {
66+
type Bound = LogicalExpression<T::Bound, N>;
67+
68+
fn bind(self, schema: SchemaRef, case_sensitive: bool) -> Result<Self::Bound> {
69+
let mut bound_inputs = MaybeUninit::<[Box<T::Bound>; N]>::uninit();
70+
for (i, input) in self.inputs.into_iter().enumerate() {
71+
let input = input.bind(schema.clone(), case_sensitive)?;
72+
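// SAFETY: `bound_inputs` has room for `N` boxes and `i < N` (one iteration per element of `self.inputs`), so the write is in bounds. Note: boxes already written are leaked, not dropped, if a later `bind` call returns early via `?`.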
73+
unsafe {
74+
bound_inputs
75+
.as_mut_ptr()
76+
.cast::<Box<T::Bound>>()
77+
.add(i)
78+
.write(Box::new(input));
79+
}
80+
}
81+
82+
// SAFETY: The loop above completed without an early return, so every one of the `N` slots has been written.
83+
let bound_inputs = unsafe { bound_inputs.assume_init() };
84+
Ok(LogicalExpression::new(bound_inputs))
85+
}
86+
}
87+
5888
/// Unary predicate, for example, `a IS NULL`.
5989
#[derive(PartialEq)]
6090
pub struct UnaryExpression<T> {
@@ -79,6 +109,15 @@ impl<T: Display> Display for UnaryExpression<T> {
79109
}
80110
}
81111

112+
impl<T: Bind> Bind for UnaryExpression<T> {
113+
type Bound = UnaryExpression<T::Bound>;
114+
115+
fn bind(self, schema: SchemaRef, case_sensitive: bool) -> Result<Self::Bound> {
116+
let bound_term = self.term.bind(schema, case_sensitive)?;
117+
Ok(UnaryExpression::new(self.op, bound_term))
118+
}
119+
}
120+
82121
impl<T> UnaryExpression<T> {
83122
pub(crate) fn new(op: PredicateOperator, term: T) -> Self {
84123
debug_assert!(op.is_unary());
@@ -120,6 +159,15 @@ impl<T: Display> Display for BinaryExpression<T> {
120159
}
121160
}
122161

162+
impl<T: Bind> Bind for BinaryExpression<T> {
163+
type Bound = BinaryExpression<T::Bound>;
164+
165+
fn bind(self, schema: SchemaRef, case_sensitive: bool) -> Result<Self::Bound> {
166+
let bound_term = self.term.bind(schema.clone(), case_sensitive)?;
167+
Ok(BinaryExpression::new(self.op, bound_term, self.literal))
168+
}
169+
}
170+
123171
/// Set predicates, for example, `a in (1, 2, 3)`.
124172
#[derive(PartialEq)]
125173
pub struct SetExpression<T> {
@@ -128,7 +176,7 @@ pub struct SetExpression<T> {
128176
/// Term of this predicate, for example, `a` in `a in (1, 2, 3)`.
129177
term: T,
130178
/// Literals of this predicate, for example, `(1, 2, 3)` in `a in (1, 2, 3)`.
131-
literals: HashSet<Datum>,
179+
literals: FnvHashSet<Datum>,
132180
}
133181

134182
impl<T: Debug> Debug for SetExpression<T> {
@@ -141,12 +189,22 @@ impl<T: Debug> Debug for SetExpression<T> {
141189
}
142190
}
143191

144-
impl<T: Debug> SetExpression<T> {
145-
pub(crate) fn new(op: PredicateOperator, term: T, literals: HashSet<Datum>) -> Self {
192+
impl<T> SetExpression<T> {
193+
pub(crate) fn new(op: PredicateOperator, term: T, literals: FnvHashSet<Datum>) -> Self {
194+
debug_assert!(op.is_set());
146195
Self { op, term, literals }
147196
}
148197
}
149198

199+
impl<T: Bind> Bind for SetExpression<T> {
200+
type Bound = SetExpression<T::Bound>;
201+
202+
fn bind(self, schema: SchemaRef, case_sensitive: bool) -> Result<Self::Bound> {
203+
let bound_term = self.term.bind(schema.clone(), case_sensitive)?;
204+
Ok(SetExpression::new(self.op, bound_term, self.literals))
205+
}
206+
}
207+
150208
impl<T: Display + Debug> Display for SetExpression<T> {
151209
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
152210
let mut literal_strs = self.literals.iter().map(|l| format!("{}", l));
@@ -172,6 +230,146 @@ pub enum Predicate {
172230
Set(SetExpression<Reference>),
173231
}
174232

233+
impl Bind for Predicate {
234+
type Bound = BoundPredicate;
235+
236+
fn bind(self, schema: SchemaRef, case_sensitive: bool) -> Result<BoundPredicate> {
237+
match self {
238+
Predicate::And(expr) => {
239+
let bound_expr = expr.bind(schema, case_sensitive)?;
240+
241+
let [left, right] = bound_expr.inputs;
242+
Ok(match (left, right) {
243+
(_, r) if matches!(&*r, &BoundPredicate::AlwaysFalse) => {
244+
BoundPredicate::AlwaysFalse
245+
}
246+
(l, _) if matches!(&*l, &BoundPredicate::AlwaysFalse) => {
247+
BoundPredicate::AlwaysFalse
248+
}
249+
(left, r) if matches!(&*r, &BoundPredicate::AlwaysTrue) => *left,
250+
(l, right) if matches!(&*l, &BoundPredicate::AlwaysTrue) => *right,
251+
(left, right) => BoundPredicate::And(LogicalExpression::new([left, right])),
252+
})
253+
}
254+
Predicate::Not(expr) => {
255+
let bound_expr = expr.bind(schema, case_sensitive)?;
256+
let [inner] = bound_expr.inputs;
257+
Ok(match inner {
258+
e if matches!(&*e, &BoundPredicate::AlwaysTrue) => BoundPredicate::AlwaysFalse,
259+
e if matches!(&*e, &BoundPredicate::AlwaysFalse) => BoundPredicate::AlwaysTrue,
260+
e => BoundPredicate::Not(LogicalExpression::new([e])),
261+
})
262+
}
263+
Predicate::Or(expr) => {
264+
let bound_expr = expr.bind(schema, case_sensitive)?;
265+
let [left, right] = bound_expr.inputs;
266+
Ok(match (left, right) {
267+
(_, r) if matches!(&*r, &BoundPredicate::AlwaysTrue) => {
268+
BoundPredicate::AlwaysTrue
269+
}
270+
(l, _) if matches!(&*l, &BoundPredicate::AlwaysTrue) => {
271+
BoundPredicate::AlwaysTrue
272+
}
273+
(left, r) if matches!(&*r, &BoundPredicate::AlwaysFalse) => *left,
274+
(l, right) if matches!(&*l, &BoundPredicate::AlwaysFalse) => *right,
275+
(left, right) => BoundPredicate::Or(LogicalExpression::new([left, right])),
276+
})
277+
}
278+
Predicate::Unary(expr) => {
279+
let bound_expr = expr.bind(schema, case_sensitive)?;
280+
281+
match &bound_expr.op {
282+
&PredicateOperator::IsNull => {
283+
if bound_expr.term.field().required {
284+
return Ok(BoundPredicate::AlwaysFalse);
285+
}
286+
}
287+
&PredicateOperator::NotNull => {
288+
if bound_expr.term.field().required {
289+
return Ok(BoundPredicate::AlwaysTrue);
290+
}
291+
}
292+
&PredicateOperator::IsNan | &PredicateOperator::NotNan => {
293+
if !bound_expr.term.field().field_type.is_floating_type() {
294+
return Err(Error::new(
295+
ErrorKind::DataInvalid,
296+
format!(
297+
"Expecting floating point type, but found {}",
298+
bound_expr.term.field().field_type
299+
),
300+
));
301+
}
302+
}
303+
op => {
304+
return Err(Error::new(
305+
ErrorKind::Unexpected,
306+
format!("Expecting unary operator,but found {op}"),
307+
))
308+
}
309+
}
310+
311+
Ok(BoundPredicate::Unary(bound_expr))
312+
}
313+
Predicate::Binary(expr) => {
314+
let bound_expr = expr.bind(schema, case_sensitive)?;
315+
let bound_literal = bound_expr.literal.to(&bound_expr.term.field().field_type)?;
316+
Ok(BoundPredicate::Binary(BinaryExpression::new(
317+
bound_expr.op,
318+
bound_expr.term,
319+
bound_literal,
320+
)))
321+
}
322+
Predicate::Set(expr) => {
323+
let bound_expr = expr.bind(schema, case_sensitive)?;
324+
let bound_literals = bound_expr
325+
.literals
326+
.into_iter()
327+
.map(|l| l.to(&bound_expr.term.field().field_type))
328+
.collect::<Result<FnvHashSet<Datum>>>()?;
329+
330+
match &bound_expr.op {
331+
&PredicateOperator::In => {
332+
if bound_literals.is_empty() {
333+
return Ok(BoundPredicate::AlwaysFalse);
334+
}
335+
if bound_literals.len() == 1 {
336+
return Ok(BoundPredicate::Binary(BinaryExpression::new(
337+
PredicateOperator::Eq,
338+
bound_expr.term,
339+
bound_literals.into_iter().next().unwrap(),
340+
)));
341+
}
342+
}
343+
&PredicateOperator::NotIn => {
344+
if bound_literals.is_empty() {
345+
return Ok(BoundPredicate::AlwaysTrue);
346+
}
347+
if bound_literals.len() == 1 {
348+
return Ok(BoundPredicate::Binary(BinaryExpression::new(
349+
PredicateOperator::NotEq,
350+
bound_expr.term,
351+
bound_literals.into_iter().next().unwrap(),
352+
)));
353+
}
354+
}
355+
op => {
356+
return Err(Error::new(
357+
ErrorKind::Unexpected,
358+
format!("Expecting unary operator,but found {op}"),
359+
))
360+
}
361+
}
362+
363+
Ok(BoundPredicate::Set(SetExpression::new(
364+
bound_expr.op,
365+
bound_expr.term,
366+
bound_literals,
367+
)))
368+
}
369+
}
370+
}
371+
}
372+
175373
impl Display for Predicate {
176374
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
177375
match self {
@@ -415,4 +613,31 @@ mod tests {
415613

416614
assert_eq!(result, expected);
417615
}
616+
617+
use crate::expr::{Bind, Reference};
618+
use crate::spec::{NestedField, PrimitiveType, Schema, SchemaRef, Type};
619+
use std::sync::Arc;
620+
621+
fn table_schema_simple() -> SchemaRef {
622+
Arc::new(
623+
Schema::builder()
624+
.with_schema_id(1)
625+
.with_identifier_field_ids(vec![2])
626+
.with_fields(vec![
627+
NestedField::optional(1, "foo", Type::Primitive(PrimitiveType::String)).into(),
628+
NestedField::required(2, "bar", Type::Primitive(PrimitiveType::Int)).into(),
629+
NestedField::optional(3, "baz", Type::Primitive(PrimitiveType::Boolean)).into(),
630+
])
631+
.build()
632+
.unwrap(),
633+
)
634+
}
635+
636+
#[test]
637+
fn test_bind_is_null() {
638+
let schema = table_schema_simple();
639+
let expr = Reference::new("foo").is_null();
640+
let bound_expr = expr.bind(schema, true).unwrap();
641+
assert_eq!(&format!("{bound_expr}"), "foo IS NULL");
642+
}
418643
}

0 commit comments

Comments
 (0)