@@ -343,6 +343,16 @@ impl RowSelection {
343343 intersect_row_selections ( & self . selectors , & other. selectors )
344344 }
345345
346+ /// Compute the union of two [`RowSelection`]
347+ /// For example:
348+ /// self: NNYYYYNNYYNYN
349+ /// other: NYNNNNNNN
350+ ///
351+ /// returned: NYYYYYNNYYNYN
352+ pub fn union ( & self , other : & Self ) -> Self {
353+ union_row_selections ( & self . selectors , & other. selectors )
354+ }
355+
346356 /// Returns `true` if this [`RowSelection`] selects any rows
347357 pub fn selects_any ( & self ) -> bool {
348358 self . selectors . iter ( ) . any ( |x| !x. skip )
@@ -536,6 +546,92 @@ fn intersect_row_selections(left: &[RowSelector], right: &[RowSelector]) -> RowS
536546 iter. collect ( )
537547}
538548
549+ /// Combine two lists of `RowSelector` return the union of them
550+ /// For example:
551+ /// self: NNYYYYNNYYNYN
552+ /// other: NYNNNNNNY
553+ ///
554+ /// returned: NYYYYYNNYYNYN
555+ ///
556+ /// This can be removed from here once RowSelection::union is in parquet::arrow
557+ fn union_row_selections ( left : & [ RowSelector ] , right : & [ RowSelector ] ) -> RowSelection {
558+ let mut l_iter = left. iter ( ) . copied ( ) . peekable ( ) ;
559+ let mut r_iter = right. iter ( ) . copied ( ) . peekable ( ) ;
560+
561+ let iter = std:: iter:: from_fn ( move || {
562+ loop {
563+ let l = l_iter. peek_mut ( ) ;
564+ let r = r_iter. peek_mut ( ) ;
565+
566+ match ( l, r) {
567+ ( Some ( a) , _) if a. row_count == 0 => {
568+ l_iter. next ( ) . unwrap ( ) ;
569+ }
570+ ( _, Some ( b) ) if b. row_count == 0 => {
571+ r_iter. next ( ) . unwrap ( ) ;
572+ }
573+ ( Some ( l) , Some ( r) ) => {
574+ return match ( l. skip , r. skip ) {
575+ // Skip both ranges
576+ ( true , true ) => {
577+ if l. row_count < r. row_count {
578+ let skip = l. row_count ;
579+ r. row_count -= l. row_count ;
580+ l_iter. next ( ) ;
581+ Some ( RowSelector :: skip ( skip) )
582+ } else {
583+ let skip = r. row_count ;
584+ l. row_count -= skip;
585+ r_iter. next ( ) ;
586+ Some ( RowSelector :: skip ( skip) )
587+ }
588+ }
589+ // Keep rows from left
590+ ( false , true ) => {
591+ if l. row_count < r. row_count {
592+ r. row_count -= l. row_count ;
593+ l_iter. next ( )
594+ } else {
595+ let r_row_count = r. row_count ;
596+ l. row_count -= r_row_count;
597+ r_iter. next ( ) ;
598+ Some ( RowSelector :: select ( r_row_count) )
599+ }
600+ }
601+ // Keep rows from right
602+ ( true , false ) => {
603+ if l. row_count < r. row_count {
604+ let l_row_count = l. row_count ;
605+ r. row_count -= l_row_count;
606+ l_iter. next ( ) ;
607+ Some ( RowSelector :: select ( l_row_count) )
608+ } else {
609+ l. row_count -= r. row_count ;
610+ r_iter. next ( )
611+ }
612+ }
613+ // Keep at least one
614+ _ => {
615+ if l. row_count < r. row_count {
616+ r. row_count -= l. row_count ;
617+ l_iter. next ( )
618+ } else {
619+ l. row_count -= r. row_count ;
620+ r_iter. next ( )
621+ }
622+ }
623+ } ;
624+ }
625+ ( Some ( _) , None ) => return l_iter. next ( ) ,
626+ ( None , Some ( _) ) => return r_iter. next ( ) ,
627+ ( None , None ) => return None ,
628+ }
629+ }
630+ } ) ;
631+
632+ iter. collect ( )
633+ }
634+
539635#[ cfg( test) ]
540636mod tests {
541637 use super :: * ;
@@ -1213,4 +1309,40 @@ mod tests {
12131309 ]
12141310 ) ;
12151311 }
1312+
#[test]
fn test_union() {
    // The union of a selection with itself must leave it unchanged.
    let everything = RowSelection::from(vec![RowSelector::select(1048576)]);
    assert_eq!(everything.union(&everything), everything);

    // In the diagrams below each letter stands for a run of 10 rows.
    // lhs: NYNYY
    let lhs = RowSelection::from(vec![
        RowSelector::skip(10),
        RowSelector::select(10),
        RowSelector::skip(10),
        RowSelector::select(20),
    ]);

    // rhs: NNYYNYN
    let rhs = RowSelection::from(vec![
        RowSelector::skip(20),
        RowSelector::select(20),
        RowSelector::skip(10),
        RowSelector::select(10),
        RowSelector::skip(10),
    ]);

    // expected: NYYYYYN
    let expected = [
        RowSelector::skip(10),
        RowSelector::select(50),
        RowSelector::skip(10),
    ];

    let combined = lhs.union(&rhs);
    assert_eq!(
        combined.iter().collect::<Vec<_>>(),
        expected.iter().collect::<Vec<_>>()
    );
}
12161348}
0 commit comments