@@ -178,6 +178,9 @@ impl DynamicFilterPhysicalExpr {
178178 )
179179 } ) ?;
180180 // Remap the children of the new expression to match the original children
181+ // We still do this again in `current()` but doing it preventively here
182+ // reduces the work needed in some cases if `current()` is called multiple times
183+ // and the same externally facing `PhysicalExpr` is used for both `with_new_children` and `update()`.`
181184 let new_expr = Self :: remap_children (
182185 & self . children ,
183186 self . remapped_children . as_ref ( ) ,
@@ -310,31 +313,105 @@ mod test {
310313 Field :: new( "a" , DataType :: Int32 , false ) ,
311314 Field :: new( "b" , DataType :: Int32 , false ) ,
312315 ] ) ) ;
313- let file_schema = Arc :: new ( Schema :: new ( vec ! [
314- Field :: new( "b" , DataType :: Int32 , false ) ,
315- Field :: new( "a" , DataType :: Int32 , false ) ,
316- ] ) ) ;
317316 let expr = Arc :: new ( BinaryExpr :: new (
318317 col ( "a" , & table_schema) . unwrap ( ) ,
319- datafusion_expr:: Operator :: Gt ,
318+ datafusion_expr:: Operator :: Eq ,
320319 lit ( 42 ) as Arc < dyn PhysicalExpr > ,
321320 ) ) ;
322321 let dynamic_filter = Arc :: new ( DynamicFilterPhysicalExpr :: new (
323322 vec ! [ col( "a" , & table_schema) . unwrap( ) ] ,
324323 expr as Arc < dyn PhysicalExpr > ,
325324 ) ) ;
326- // Take an initial snapshot
327- let snap = dynamic_filter. snapshot ( ) . unwrap ( ) . unwrap ( ) ;
328- insta:: assert_snapshot!( format!( "{snap:?}" ) , @r#"BinaryExpr { left: Column { name: "a", index: 0 }, op: Gt, right: Literal { value: Int32(42) }, fail_on_overflow: false }"# ) ;
329- let snap_string = snap. to_string ( ) ;
330- // Remap the children to the file schema
331- let dynamic_filter =
332- reassign_predicate_columns ( dynamic_filter, & file_schema, false ) . unwrap ( ) ;
333- // Take a snapshot after remapping, the children in the snapshot should be remapped to the file schema
334- let new_snap = dynamic_filter. snapshot ( ) . unwrap ( ) . unwrap ( ) ;
335- insta:: assert_snapshot!( format!( "{new_snap:?}" ) , @r#"BinaryExpr { left: Column { name: "a", index: 1 }, op: Gt, right: Literal { value: Int32(42) }, fail_on_overflow: false }"# ) ;
336- // The original snapshot should not have changed
337- assert_eq ! ( snap. to_string( ) , snap_string) ;
325+ // Simulate two `ParquetSource` files with different filter schemas
326+ // Both of these should hit the same inner `PhysicalExpr` even after `update()` is called
327+ // and be able to remap children independently.
328+ let filter_schema_1 = Arc :: new ( Schema :: new ( vec ! [
329+ Field :: new( "a" , DataType :: Int32 , false ) ,
330+ Field :: new( "b" , DataType :: Int32 , false ) ,
331+ ] ) ) ;
332+ let filter_schema_2 = Arc :: new ( Schema :: new ( vec ! [
333+ Field :: new( "b" , DataType :: Int32 , false ) ,
334+ Field :: new( "a" , DataType :: Int32 , false ) ,
335+ ] ) ) ;
336+ // Each ParquetExec calls `with_new_children` on the DynamicFilterPhysicalExpr
337+ // and remaps the children to the file schema.
338+ let dynamic_filter_1 = reassign_predicate_columns (
339+ Arc :: clone ( & dynamic_filter) as Arc < dyn PhysicalExpr > ,
340+ & filter_schema_1,
341+ false ,
342+ )
343+ . unwrap ( ) ;
344+ let snap = dynamic_filter_1. snapshot ( ) . unwrap ( ) . unwrap ( ) ;
345+ insta:: assert_snapshot!( format!( "{snap:?}" ) , @r#"BinaryExpr { left: Column { name: "a", index: 0 }, op: Eq, right: Literal { value: Int32(42) }, fail_on_overflow: false }"# ) ;
346+ let dynamic_filter_2 = reassign_predicate_columns (
347+ Arc :: clone ( & dynamic_filter) as Arc < dyn PhysicalExpr > ,
348+ & filter_schema_2,
349+ false ,
350+ )
351+ . unwrap ( ) ;
352+ let snap = dynamic_filter_2. snapshot ( ) . unwrap ( ) . unwrap ( ) ;
353+ insta:: assert_snapshot!( format!( "{snap:?}" ) , @r#"BinaryExpr { left: Column { name: "a", index: 1 }, op: Eq, right: Literal { value: Int32(42) }, fail_on_overflow: false }"# ) ;
354+ // Both filters allow evaluating the same expression
355+ let batch_1 = RecordBatch :: try_new (
356+ Arc :: clone ( & filter_schema_1) ,
357+ vec ! [
358+ // a
359+ ScalarValue :: Int32 ( Some ( 42 ) ) . to_array_of_size( 1 ) . unwrap( ) ,
360+ // b
361+ ScalarValue :: Int32 ( Some ( 43 ) ) . to_array_of_size( 1 ) . unwrap( ) ,
362+ ] ,
363+ )
364+ . unwrap ( ) ;
365+ let batch_2 = RecordBatch :: try_new (
366+ Arc :: clone ( & filter_schema_2) ,
367+ vec ! [
368+ // b
369+ ScalarValue :: Int32 ( Some ( 43 ) ) . to_array_of_size( 1 ) . unwrap( ) ,
370+ // a
371+ ScalarValue :: Int32 ( Some ( 42 ) ) . to_array_of_size( 1 ) . unwrap( ) ,
372+ ] ,
373+ )
374+ . unwrap ( ) ;
375+ // Evaluate the expression on both batches
376+ let result_1 = dynamic_filter_1. evaluate ( & batch_1) . unwrap ( ) ;
377+ let result_2 = dynamic_filter_2. evaluate ( & batch_2) . unwrap ( ) ;
378+ // Check that the results are the same
379+ let ColumnarValue :: Array ( arr_1) = result_1 else {
380+ panic ! ( "Expected ColumnarValue::Array" ) ;
381+ } ;
382+ let ColumnarValue :: Array ( arr_2) = result_2 else {
383+ panic ! ( "Expected ColumnarValue::Array" ) ;
384+ } ;
385+ assert ! ( arr_1. eq( & arr_2) ) ;
386+ let expected = ScalarValue :: Boolean ( Some ( true ) )
387+ . to_array_of_size ( 1 )
388+ . unwrap ( ) ;
389+ assert ! ( arr_1. eq( & expected) ) ;
390+ // Now lets update the expression
391+ // Note that we update the *original* expression and that should be reflected in both the derived expressions
392+ let new_expr = Arc :: new ( BinaryExpr :: new (
393+ col ( "a" , & table_schema) . unwrap ( ) ,
394+ datafusion_expr:: Operator :: Gt ,
395+ lit ( 43 ) as Arc < dyn PhysicalExpr > ,
396+ ) ) ;
397+ dynamic_filter
398+ . update ( Arc :: clone ( & new_expr) as Arc < dyn PhysicalExpr > )
399+ . expect ( "Failed to update expression" ) ;
400+ // Now we should be able to evaluate the new expression on both batches
401+ let result_1 = dynamic_filter_1. evaluate ( & batch_1) . unwrap ( ) ;
402+ let result_2 = dynamic_filter_2. evaluate ( & batch_2) . unwrap ( ) ;
403+ // Check that the results are the same
404+ let ColumnarValue :: Array ( arr_1) = result_1 else {
405+ panic ! ( "Expected ColumnarValue::Array" ) ;
406+ } ;
407+ let ColumnarValue :: Array ( arr_2) = result_2 else {
408+ panic ! ( "Expected ColumnarValue::Array" ) ;
409+ } ;
410+ assert ! ( arr_1. eq( & arr_2) ) ;
411+ let expected = ScalarValue :: Boolean ( Some ( false ) )
412+ . to_array_of_size ( 1 )
413+ . unwrap ( ) ;
414+ assert ! ( arr_1. eq( & expected) ) ;
338415 }
339416
340417 #[ test]
0 commit comments