@@ -528,25 +528,6 @@ impl AggregateUDFImpl for ForeignAggregateUDF {
528528 }
529529}
530530
531- #[ cfg( test) ]
532- mod tests {
533- use super :: * ;
534-
535- #[ test]
536- fn test_round_trip_udaf ( ) -> Result < ( ) > {
537- let original_udaf = datafusion:: functions_aggregate:: sum:: Sum :: new ( ) ;
538- let original_udaf = Arc :: new ( AggregateUDF :: from ( original_udaf) ) ;
539-
540- let local_udaf: FFI_AggregateUDF = Arc :: clone ( & original_udaf) . into ( ) ;
541-
542- let foreign_udaf: ForeignAggregateUDF = ( & local_udaf) . try_into ( ) ?;
543-
544- assert ! ( original_udaf. name( ) == foreign_udaf. name( ) ) ;
545-
546- Ok ( ( ) )
547- }
548- }
549-
550531#[ repr( C ) ]
551532#[ derive( Debug , StableAbi ) ]
552533#[ allow( non_camel_case_types) ]
@@ -575,3 +556,152 @@ impl From<AggregateOrderSensitivity> for FFI_AggregateOrderSensitivity {
575556 }
576557 }
577558}
559+
560+ #[ cfg( test) ]
561+ mod tests {
562+ use arrow:: datatypes:: Schema ;
563+ use datafusion:: {
564+ common:: create_array,
565+ functions_aggregate:: sum:: Sum ,
566+ physical_expr:: { LexOrdering , PhysicalSortExpr } ,
567+ physical_plan:: expressions:: col,
568+ scalar:: ScalarValue ,
569+ } ;
570+
571+ use super :: * ;
572+
573+ fn create_test_foreign_udaf (
574+ original_udaf : impl AggregateUDFImpl + ' static ,
575+ ) -> Result < AggregateUDF > {
576+ let original_udaf = Arc :: new ( AggregateUDF :: from ( original_udaf) ) ;
577+
578+ let local_udaf: FFI_AggregateUDF = Arc :: clone ( & original_udaf) . into ( ) ;
579+
580+ let foreign_udaf: ForeignAggregateUDF = ( & local_udaf) . try_into ( ) ?;
581+ Ok ( foreign_udaf. into ( ) )
582+ }
583+
584+ #[ test]
585+ fn test_round_trip_udaf ( ) -> Result < ( ) > {
586+ let original_udaf = Sum :: new ( ) ;
587+ let original_name = original_udaf. name ( ) . to_owned ( ) ;
588+
589+ let foreign_udaf = create_test_foreign_udaf ( original_udaf) ?;
590+ // let original_udaf = Arc::new(AggregateUDF::from(original_udaf));
591+
592+ // let local_udaf: FFI_AggregateUDF = Arc::clone(&original_udaf).into();
593+
594+ // let foreign_udaf: ForeignAggregateUDF = (&local_udaf).try_into()?;
595+ // let foreign_udaf: AggregateUDF = foreign_udaf.into();
596+
597+ assert_eq ! ( original_name, foreign_udaf. name( ) ) ;
598+ Ok ( ( ) )
599+ }
600+
601+ #[ test]
602+ fn test_foreign_udaf_aliases ( ) -> Result < ( ) > {
603+ let foreign_udaf =
604+ create_test_foreign_udaf ( Sum :: new ( ) ) ?. with_aliases ( [ "my_function" ] ) ;
605+
606+ let return_type = foreign_udaf. return_type ( & [ DataType :: Float64 ] ) ?;
607+ assert_eq ! ( return_type, DataType :: Float64 ) ;
608+ Ok ( ( ) )
609+ }
610+
611+ #[ test]
612+ fn test_foreign_udaf_accumulator ( ) -> Result < ( ) > {
613+ let foreign_udaf = create_test_foreign_udaf ( Sum :: new ( ) ) ?;
614+
615+ let schema = Schema :: new ( vec ! [ Field :: new( "a" , DataType :: Float64 , true ) ] ) ;
616+ let acc_args = AccumulatorArgs {
617+ return_type : & DataType :: Float64 ,
618+ schema : & schema,
619+ ignore_nulls : true ,
620+ ordering_req : & LexOrdering :: new ( vec ! [ PhysicalSortExpr {
621+ expr: col( "a" , & schema) ?,
622+ options: Default :: default ( ) ,
623+ } ] ) ,
624+ is_reversed : false ,
625+ name : "round_trip" ,
626+ is_distinct : true ,
627+ exprs : & [ col ( "a" , & schema) ?] ,
628+ } ;
629+ let mut accumulator = foreign_udaf. accumulator ( acc_args) ?;
630+ let values = create_array ! ( Float64 , vec![ 10. , 20. , 30. , 40. , 50. ] ) ;
631+ accumulator. update_batch ( & [ values] ) ?;
632+ let resultant_value = accumulator. evaluate ( ) ?;
633+ assert_eq ! ( resultant_value, ScalarValue :: Float64 ( Some ( 150. ) ) ) ;
634+
635+ Ok ( ( ) )
636+ }
637+
638+ #[ test]
639+ fn test_beneficial_ordering ( ) -> Result < ( ) > {
640+ let foreign_udaf = create_test_foreign_udaf (
641+ datafusion:: functions_aggregate:: first_last:: FirstValue :: new ( ) ,
642+ ) ?;
643+
644+ let foreign_udaf = foreign_udaf. with_beneficial_ordering ( true ) ?. unwrap ( ) ;
645+
646+ assert_eq ! (
647+ foreign_udaf. order_sensitivity( ) ,
648+ AggregateOrderSensitivity :: Beneficial
649+ ) ;
650+
651+ let a_field = Field :: new ( "a" , DataType :: Float64 , true ) ;
652+ let state_fields = foreign_udaf. state_fields ( StateFieldsArgs {
653+ name : "a" ,
654+ input_types : & [ DataType :: Float64 ] ,
655+ return_type : & DataType :: Float64 ,
656+ ordering_fields : & [ a_field. clone ( ) ] ,
657+ is_distinct : false ,
658+ } ) ?;
659+
660+ println ! ( "{:#?}" , state_fields) ;
661+ assert_eq ! ( state_fields. len( ) , 3 ) ;
662+ assert_eq ! ( state_fields[ 1 ] , a_field) ;
663+ Ok ( ( ) )
664+ }
665+
666+ #[ test]
667+ fn test_sliding_accumulator ( ) -> Result < ( ) > {
668+ let foreign_udaf = create_test_foreign_udaf ( Sum :: new ( ) ) ?;
669+
670+ let schema = Schema :: new ( vec ! [ Field :: new( "a" , DataType :: Float64 , true ) ] ) ;
671+ let acc_args = AccumulatorArgs {
672+ return_type : & DataType :: Float64 ,
673+ schema : & schema,
674+ ignore_nulls : true ,
675+ ordering_req : & LexOrdering :: new ( vec ! [ PhysicalSortExpr {
676+ expr: col( "a" , & schema) ?,
677+ options: Default :: default ( ) ,
678+ } ] ) ,
679+ is_reversed : false ,
680+ name : "round_trip" ,
681+ is_distinct : true ,
682+ exprs : & [ col ( "a" , & schema) ?] ,
683+ } ;
684+
685+ let mut accumulator = foreign_udaf. create_sliding_accumulator ( acc_args) ?;
686+ let values = create_array ! ( Float64 , vec![ 10. , 20. , 30. , 40. , 50. ] ) ;
687+ accumulator. update_batch ( & [ values] ) ?;
688+ let resultant_value = accumulator. evaluate ( ) ?;
689+ assert_eq ! ( resultant_value, ScalarValue :: Float64 ( Some ( 150. ) ) ) ;
690+
691+ Ok ( ( ) )
692+ }
693+
694+ fn test_round_trip_order_sensitivity ( sensitivity : AggregateOrderSensitivity ) {
695+ let ffi_sensitivity: FFI_AggregateOrderSensitivity = sensitivity. into ( ) ;
696+ let round_trip_sensitivity: AggregateOrderSensitivity = ffi_sensitivity. into ( ) ;
697+
698+ assert_eq ! ( sensitivity, round_trip_sensitivity) ;
699+ }
700+
701+ #[ test]
702+ fn test_round_trip_all_order_sensitivities ( ) {
703+ test_round_trip_order_sensitivity ( AggregateOrderSensitivity :: Insensitive ) ;
704+ test_round_trip_order_sensitivity ( AggregateOrderSensitivity :: HardRequirement ) ;
705+ test_round_trip_order_sensitivity ( AggregateOrderSensitivity :: Beneficial ) ;
706+ }
707+ }
0 commit comments