@@ -23,7 +23,14 @@ use crate::expr::{BoundReference, PredicateOperator, Reference};
2323use crate :: spec:: Datum ;
2424use itertools:: Itertools ;
2525use std:: collections:: HashSet ;
26+ use crate :: error:: Result ;
27+ use crate :: expr:: { Bind , BoundReference , PredicateOperator , Reference } ;
28+ use crate :: spec:: { Datum , SchemaRef } ;
29+ use crate :: { Error , ErrorKind } ;
30+ use fnv:: FnvHashSet ;
31+
2632use std:: fmt:: { Debug , Display , Formatter } ;
33+ use std:: mem:: MaybeUninit ;
2734use std:: ops:: Not ;
2835
2936/// Logical expression, such as `AND`, `OR`, `NOT`.
@@ -55,6 +62,29 @@ impl<T, const N: usize> LogicalExpression<T, N> {
5562 }
5663}
5764
65+ impl < T : Bind , const N : usize > Bind for LogicalExpression < T , N > {
66+ type Bound = LogicalExpression < T :: Bound , N > ;
67+
68+ fn bind ( self , schema : SchemaRef , case_sensitive : bool ) -> Result < Self :: Bound > {
69+ let mut bound_inputs = MaybeUninit :: < [ Box < T :: Bound > ; N ] > :: uninit ( ) ;
70+ for ( i, input) in self . inputs . into_iter ( ) . enumerate ( ) {
71+ let input = input. bind ( schema. clone ( ) , case_sensitive) ?;
72+ // SAFETY: The pointer is valid from [`MaybeUninit`].
73+ unsafe {
74+ bound_inputs
75+ . as_mut_ptr ( )
76+ . cast :: < Box < T :: Bound > > ( )
77+ . add ( i)
78+ . write ( Box :: new ( input) ) ;
79+ }
80+ }
81+
82+ // SAFETY: We have initialized all elements of the array.
83+ let bound_inputs = unsafe { bound_inputs. assume_init ( ) } ;
84+ Ok ( LogicalExpression :: new ( bound_inputs) )
85+ }
86+ }
87+
5888/// Unary predicate, for example, `a IS NULL`.
5989#[ derive( PartialEq ) ]
6090pub struct UnaryExpression < T > {
@@ -79,6 +109,15 @@ impl<T: Display> Display for UnaryExpression<T> {
79109 }
80110}
81111
112+ impl < T : Bind > Bind for UnaryExpression < T > {
113+ type Bound = UnaryExpression < T :: Bound > ;
114+
115+ fn bind ( self , schema : SchemaRef , case_sensitive : bool ) -> Result < Self :: Bound > {
116+ let bound_term = self . term . bind ( schema, case_sensitive) ?;
117+ Ok ( UnaryExpression :: new ( self . op , bound_term) )
118+ }
119+ }
120+
82121impl < T > UnaryExpression < T > {
83122 pub ( crate ) fn new ( op : PredicateOperator , term : T ) -> Self {
84123 debug_assert ! ( op. is_unary( ) ) ;
@@ -120,6 +159,15 @@ impl<T: Display> Display for BinaryExpression<T> {
120159 }
121160}
122161
162+ impl < T : Bind > Bind for BinaryExpression < T > {
163+ type Bound = BinaryExpression < T :: Bound > ;
164+
165+ fn bind ( self , schema : SchemaRef , case_sensitive : bool ) -> Result < Self :: Bound > {
166+ let bound_term = self . term . bind ( schema. clone ( ) , case_sensitive) ?;
167+ Ok ( BinaryExpression :: new ( self . op , bound_term, self . literal ) )
168+ }
169+ }
170+
123171/// Set predicates, for example, `a in (1, 2, 3)`.
124172#[ derive( PartialEq ) ]
125173pub struct SetExpression < T > {
@@ -128,7 +176,7 @@ pub struct SetExpression<T> {
128176 /// Term of this predicate, for example, `a` in `a in (1, 2, 3)`.
129177 term : T ,
130178 /// Literals of this predicate, for example, `(1, 2, 3)` in `a in (1, 2, 3)`.
131- literals : HashSet < Datum > ,
179+ literals : FnvHashSet < Datum > ,
132180}
133181
134182impl < T : Debug > Debug for SetExpression < T > {
@@ -141,12 +189,22 @@ impl<T: Debug> Debug for SetExpression<T> {
141189 }
142190}
143191
144- impl < T : Debug > SetExpression < T > {
145- pub ( crate ) fn new ( op : PredicateOperator , term : T , literals : HashSet < Datum > ) -> Self {
192+ impl < T > SetExpression < T > {
193+ pub ( crate ) fn new ( op : PredicateOperator , term : T , literals : FnvHashSet < Datum > ) -> Self {
194+ debug_assert ! ( op. is_set( ) ) ;
146195 Self { op, term, literals }
147196 }
148197}
149198
199+ impl < T : Bind > Bind for SetExpression < T > {
200+ type Bound = SetExpression < T :: Bound > ;
201+
202+ fn bind ( self , schema : SchemaRef , case_sensitive : bool ) -> Result < Self :: Bound > {
203+ let bound_term = self . term . bind ( schema. clone ( ) , case_sensitive) ?;
204+ Ok ( SetExpression :: new ( self . op , bound_term, self . literals ) )
205+ }
206+ }
207+
150208impl < T : Display + Debug > Display for SetExpression < T > {
151209 fn fmt ( & self , f : & mut Formatter < ' _ > ) -> std:: fmt:: Result {
152210 let mut literal_strs = self . literals . iter ( ) . map ( |l| format ! ( "{}" , l) ) ;
@@ -172,6 +230,146 @@ pub enum Predicate {
172230 Set ( SetExpression < Reference > ) ,
173231}
174232
233+ impl Bind for Predicate {
234+ type Bound = BoundPredicate ;
235+
236+ fn bind ( self , schema : SchemaRef , case_sensitive : bool ) -> Result < BoundPredicate > {
237+ match self {
238+ Predicate :: And ( expr) => {
239+ let bound_expr = expr. bind ( schema, case_sensitive) ?;
240+
241+ let [ left, right] = bound_expr. inputs ;
242+ Ok ( match ( left, right) {
243+ ( _, r) if matches ! ( & * r, & BoundPredicate :: AlwaysFalse ) => {
244+ BoundPredicate :: AlwaysFalse
245+ }
246+ ( l, _) if matches ! ( & * l, & BoundPredicate :: AlwaysFalse ) => {
247+ BoundPredicate :: AlwaysFalse
248+ }
249+ ( left, r) if matches ! ( & * r, & BoundPredicate :: AlwaysTrue ) => * left,
250+ ( l, right) if matches ! ( & * l, & BoundPredicate :: AlwaysTrue ) => * right,
251+ ( left, right) => BoundPredicate :: And ( LogicalExpression :: new ( [ left, right] ) ) ,
252+ } )
253+ }
254+ Predicate :: Not ( expr) => {
255+ let bound_expr = expr. bind ( schema, case_sensitive) ?;
256+ let [ inner] = bound_expr. inputs ;
257+ Ok ( match inner {
258+ e if matches ! ( & * e, & BoundPredicate :: AlwaysTrue ) => BoundPredicate :: AlwaysFalse ,
259+ e if matches ! ( & * e, & BoundPredicate :: AlwaysFalse ) => BoundPredicate :: AlwaysTrue ,
260+ e => BoundPredicate :: Not ( LogicalExpression :: new ( [ e] ) ) ,
261+ } )
262+ }
263+ Predicate :: Or ( expr) => {
264+ let bound_expr = expr. bind ( schema, case_sensitive) ?;
265+ let [ left, right] = bound_expr. inputs ;
266+ Ok ( match ( left, right) {
267+ ( _, r) if matches ! ( & * r, & BoundPredicate :: AlwaysTrue ) => {
268+ BoundPredicate :: AlwaysTrue
269+ }
270+ ( l, _) if matches ! ( & * l, & BoundPredicate :: AlwaysTrue ) => {
271+ BoundPredicate :: AlwaysTrue
272+ }
273+ ( left, r) if matches ! ( & * r, & BoundPredicate :: AlwaysFalse ) => * left,
274+ ( l, right) if matches ! ( & * l, & BoundPredicate :: AlwaysFalse ) => * right,
275+ ( left, right) => BoundPredicate :: Or ( LogicalExpression :: new ( [ left, right] ) ) ,
276+ } )
277+ }
278+ Predicate :: Unary ( expr) => {
279+ let bound_expr = expr. bind ( schema, case_sensitive) ?;
280+
281+ match & bound_expr. op {
282+ & PredicateOperator :: IsNull => {
283+ if bound_expr. term . field ( ) . required {
284+ return Ok ( BoundPredicate :: AlwaysFalse ) ;
285+ }
286+ }
287+ & PredicateOperator :: NotNull => {
288+ if bound_expr. term . field ( ) . required {
289+ return Ok ( BoundPredicate :: AlwaysTrue ) ;
290+ }
291+ }
292+ & PredicateOperator :: IsNan | & PredicateOperator :: NotNan => {
293+ if !bound_expr. term . field ( ) . field_type . is_floating_type ( ) {
294+ return Err ( Error :: new (
295+ ErrorKind :: DataInvalid ,
296+ format ! (
297+ "Expecting floating point type, but found {}" ,
298+ bound_expr. term. field( ) . field_type
299+ ) ,
300+ ) ) ;
301+ }
302+ }
303+ op => {
304+ return Err ( Error :: new (
305+ ErrorKind :: Unexpected ,
306+ format ! ( "Expecting unary operator,but found {op}" ) ,
307+ ) )
308+ }
309+ }
310+
311+ Ok ( BoundPredicate :: Unary ( bound_expr) )
312+ }
313+ Predicate :: Binary ( expr) => {
314+ let bound_expr = expr. bind ( schema, case_sensitive) ?;
315+ let bound_literal = bound_expr. literal . to ( & bound_expr. term . field ( ) . field_type ) ?;
316+ Ok ( BoundPredicate :: Binary ( BinaryExpression :: new (
317+ bound_expr. op ,
318+ bound_expr. term ,
319+ bound_literal,
320+ ) ) )
321+ }
322+ Predicate :: Set ( expr) => {
323+ let bound_expr = expr. bind ( schema, case_sensitive) ?;
324+ let bound_literals = bound_expr
325+ . literals
326+ . into_iter ( )
327+ . map ( |l| l. to ( & bound_expr. term . field ( ) . field_type ) )
328+ . collect :: < Result < FnvHashSet < Datum > > > ( ) ?;
329+
330+ match & bound_expr. op {
331+ & PredicateOperator :: In => {
332+ if bound_literals. is_empty ( ) {
333+ return Ok ( BoundPredicate :: AlwaysFalse ) ;
334+ }
335+ if bound_literals. len ( ) == 1 {
336+ return Ok ( BoundPredicate :: Binary ( BinaryExpression :: new (
337+ PredicateOperator :: Eq ,
338+ bound_expr. term ,
339+ bound_literals. into_iter ( ) . next ( ) . unwrap ( ) ,
340+ ) ) ) ;
341+ }
342+ }
343+ & PredicateOperator :: NotIn => {
344+ if bound_literals. is_empty ( ) {
345+ return Ok ( BoundPredicate :: AlwaysTrue ) ;
346+ }
347+ if bound_literals. len ( ) == 1 {
348+ return Ok ( BoundPredicate :: Binary ( BinaryExpression :: new (
349+ PredicateOperator :: NotEq ,
350+ bound_expr. term ,
351+ bound_literals. into_iter ( ) . next ( ) . unwrap ( ) ,
352+ ) ) ) ;
353+ }
354+ }
355+ op => {
356+ return Err ( Error :: new (
357+ ErrorKind :: Unexpected ,
358+ format ! ( "Expecting unary operator,but found {op}" ) ,
359+ ) )
360+ }
361+ }
362+
363+ Ok ( BoundPredicate :: Set ( SetExpression :: new (
364+ bound_expr. op ,
365+ bound_expr. term ,
366+ bound_literals,
367+ ) ) )
368+ }
369+ }
370+ }
371+ }
372+
175373impl Display for Predicate {
176374 fn fmt ( & self , f : & mut Formatter < ' _ > ) -> std:: fmt:: Result {
177375 match self {
@@ -415,4 +613,31 @@ mod tests {
415613
416614 assert_eq ! ( result, expected) ;
417615 }
616+
617+ use crate :: expr:: { Bind , Reference } ;
618+ use crate :: spec:: { NestedField , PrimitiveType , Schema , SchemaRef , Type } ;
619+ use std:: sync:: Arc ;
620+
621+ fn table_schema_simple ( ) -> SchemaRef {
622+ Arc :: new (
623+ Schema :: builder ( )
624+ . with_schema_id ( 1 )
625+ . with_identifier_field_ids ( vec ! [ 2 ] )
626+ . with_fields ( vec ! [
627+ NestedField :: optional( 1 , "foo" , Type :: Primitive ( PrimitiveType :: String ) ) . into( ) ,
628+ NestedField :: required( 2 , "bar" , Type :: Primitive ( PrimitiveType :: Int ) ) . into( ) ,
629+ NestedField :: optional( 3 , "baz" , Type :: Primitive ( PrimitiveType :: Boolean ) ) . into( ) ,
630+ ] )
631+ . build ( )
632+ . unwrap ( ) ,
633+ )
634+ }
635+
636+ #[ test]
637+ fn test_bind_is_null ( ) {
638+ let schema = table_schema_simple ( ) ;
639+ let expr = Reference :: new ( "foo" ) . is_null ( ) ;
640+ let bound_expr = expr. bind ( schema, true ) . unwrap ( ) ;
641+ assert_eq ! ( & format!( "{bound_expr}" ) , "foo IS NULL" ) ;
642+ }
418643}
0 commit comments