@@ -129,7 +129,7 @@ use fallible_iterator::FallibleIterator;
129
129
use futures:: { ready, Stream } ;
130
130
use pin_project:: { pin_project, pinned_drop} ;
131
131
use postgres_protocol:: escape:: { escape_identifier, escape_literal} ;
132
- use postgres_protocol:: message:: backend:: { Message , ReplicationMessage } ;
132
+ use postgres_protocol:: message:: backend:: { Message , ReplicationMessage , RowDescriptionBody } ;
133
133
use postgres_protocol:: message:: frontend;
134
134
use std:: marker:: PhantomPinned ;
135
135
use std:: path:: { Path , PathBuf } ;
@@ -251,7 +251,7 @@ impl CreateReplicationSlotResponse {
251
251
252
252
/// Response sent after streaming from a timeline that is not the
253
253
/// current timeline.
254
- #[ derive( Debug ) ]
254
+ #[ derive( Clone , Debug ) ]
255
255
pub struct ReplicationResponse {
256
256
next_tli : u64 ,
257
257
next_tli_startpos : Lsn ,
@@ -272,15 +272,11 @@ impl ReplicationResponse {
272
272
/// Represents a client connected in replication mode.
273
273
pub struct ReplicationClient {
274
274
client : Client ,
275
- replication_stream_active : bool ,
276
275
}
277
276
278
277
impl ReplicationClient {
279
278
pub ( crate ) fn new ( client : Client ) -> ReplicationClient {
280
- ReplicationClient {
281
- client : client,
282
- replication_stream_active : false ,
283
- }
279
+ ReplicationClient { client : client }
284
280
}
285
281
}
286
282
@@ -657,30 +653,51 @@ impl ReplicationClient {
657
653
& ' a mut self ,
658
654
command : String ,
659
655
) -> Result < Pin < Box < ReplicationStream < ' a > > > , Error > {
656
+ let mut copyboth_received = false ;
657
+ let mut replication_response: Option < ReplicationResponse > = None ;
660
658
let mut responses = self . send ( & command) . await ?;
661
- self . replication_stream_active = true ;
662
659
660
+ // Before we construct the ReplicationStream, we must know
661
+ // whether the server entered copy mode or not. Otherwise, if
662
+ // the ReplicationStream were to be dropped, we wouldn't know
663
+ // whether to send a CopyDone message or not (and it would be
664
+ // bad to try to receive and process the responses during the
665
+ // destructor).
666
+
667
+ // If the timeline selected is the current one, the server
668
+ // will always enter copy mode. If the timeline is historic,
669
+ // and if there is no work to do, the server will skip copy
670
+ // mode and immediately send a response tuple.
663
671
match responses. next ( ) . await ? {
664
- Message :: CopyBothResponse ( _) => { }
672
+ Message :: CopyBothResponse ( _) => {
673
+ copyboth_received = true ;
674
+ }
675
+ Message :: RowDescription ( rowdesc) => {
676
+ // Never entered copy mode, so don't bother returning
677
+ // a stream, just process the response.
678
+ replication_response =
679
+ Some ( recv_replication_response ( & mut responses, rowdesc) . await ?) ;
680
+ }
665
681
m => return Err ( Error :: unexpected_message ( m) ) ,
666
682
}
667
683
668
684
Ok ( Box :: pin ( ReplicationStream {
669
685
rclient : self ,
670
686
responses : responses,
687
+ copyboth_received : copyboth_received,
688
+ copydone_sent : false ,
689
+ copydone_received : false ,
690
+ replication_response : replication_response,
671
691
_phantom_pinned : PhantomPinned ,
672
692
} ) )
673
693
}
674
694
675
695
fn send_copydone ( & mut self ) -> Result < ( ) , Error > {
676
- if self . replication_stream_active {
677
- let iclient = self . client . inner ( ) ;
678
- let mut buf = BytesMut :: new ( ) ;
679
- frontend:: copy_done ( & mut buf) ;
680
- iclient
681
- . unpipelined_send ( RequestMessages :: Single ( FrontendMessage :: Raw ( buf. freeze ( ) ) ) ) ?;
682
- self . replication_stream_active = false ;
683
- }
696
+ let iclient = self . client . inner ( ) ;
697
+ let mut buf = BytesMut :: new ( ) ;
698
+ frontend:: copy_done ( & mut buf) ;
699
+ iclient. unpipelined_send ( RequestMessages :: Single ( FrontendMessage :: Raw ( buf. freeze ( ) ) ) ) ?;
700
+
684
701
Ok ( ( ) )
685
702
}
686
703
}
@@ -690,71 +707,76 @@ impl ReplicationClient {
690
707
/// [CopyData](postgres_protocol::message::backend::Message::CopyData).
691
708
///
692
709
/// Intended to be used with the [next()](tokio::stream::StreamExt::next) method.
710
+ ///
711
+ /// If the timeline specified with
712
+ /// [start_physical_replication()](ReplicationClient::start_physical_replication)
713
+ /// or
714
+ /// [start_logical_replication()](ReplicationClient::start_logical_replication())
715
+ /// is the current timeline, the stream is indefinite, and must be
716
+ /// terminated with
717
+ /// [stop_replication()](ReplicationStream::stop_replication()) (which
718
+ /// will not return a response tuple); or by dropping the
719
+ /// [ReplicationStream](ReplicationStream).
720
+ ///
721
+ /// If the timeline is not the current timeline, the stream will
722
+ /// terminate when the end of the timeline is reached, and
723
+ /// [stop_replication()](ReplicationStream::stop_replication()) will
724
+ /// return a response tuple.
693
725
#[ pin_project( PinnedDrop ) ]
694
726
pub struct ReplicationStream < ' a > {
695
727
rclient : & ' a mut ReplicationClient ,
696
728
responses : Responses ,
729
+ copyboth_received : bool ,
730
+ copydone_sent : bool ,
731
+ copydone_received : bool ,
732
+ replication_response : Option < ReplicationResponse > ,
697
733
#[ pin]
698
734
_phantom_pinned : PhantomPinned ,
699
735
}
700
736
701
737
impl ReplicationStream < ' _ > {
702
738
/// Stop replication stream and return the replication client object.
703
- pub async fn stop_replication ( mut self : Pin < Box < Self > > ) -> Result < Option < ReplicationResponse > , Error > {
739
+ pub async fn stop_replication (
740
+ mut self : Pin < Box < Self > > ,
741
+ ) -> Result < Option < ReplicationResponse > , Error > {
704
742
let this = self . as_mut ( ) . project ( ) ;
705
743
706
- this. rclient . send_copydone ( ) ?;
707
- let responses = this. responses ;
708
-
709
- // drain remaining CopyData messages and CopyDone
710
- loop {
711
- match responses. next ( ) . await ? {
712
- Message :: CopyData ( _) => ( ) ,
713
- Message :: CopyDone => break ,
714
- m => return Err ( Error :: unexpected_message ( m) ) ,
715
- }
744
+ if this. replication_response . is_some ( ) {
745
+ return Ok ( this. replication_response . clone ( ) ) ;
716
746
}
717
747
718
- let next_message = responses. next ( ) . await ?;
748
+ // we must be in copy mode; shut it down
749
+ assert ! ( * this. copyboth_received) ;
750
+ if !* this. copydone_sent {
751
+ this. rclient . send_copydone ( ) ?;
752
+ * this. copydone_sent = true ;
753
+ }
719
754
720
- let response = match next_message {
721
- Message :: RowDescription ( rowdesc) => {
722
- let datarow = match responses. next ( ) . await ? {
723
- Message :: DataRow ( m) => m,
755
+ // If server didn't already shut down copy, drain remaining
756
+ // CopyData and the CopyDone.
757
+ if !* this. copydone_received {
758
+ loop {
759
+ match this. responses . next ( ) . await ? {
760
+ Message :: CopyData ( _) => ( ) ,
761
+ Message :: CopyDone => {
762
+ * this. copydone_received = true ;
763
+ break ;
764
+ }
724
765
m => return Err ( Error :: unexpected_message ( m) ) ,
725
- } ;
726
-
727
- let fields = rowdesc. fields ( ) . collect :: < Vec < _ > > ( ) . map_err ( Error :: parse) ?;
728
- let ranges = datarow. ranges ( ) . collect :: < Vec < _ > > ( ) . map_err ( Error :: parse) ?;
729
-
730
- assert_eq ! ( fields. len( ) , 2 ) ;
731
- assert_eq ! ( fields[ 0 ] . type_oid( ) , Type :: INT8 . oid( ) ) ;
732
- assert_eq ! ( fields[ 0 ] . format( ) , 0 ) ;
733
- assert_eq ! ( fields[ 1 ] . type_oid( ) , Type :: TEXT . oid( ) ) ;
734
- assert_eq ! ( fields[ 1 ] . format( ) , 0 ) ;
735
- assert_eq ! ( ranges. len( ) , 2 ) ;
736
-
737
- let timeline = & datarow. buffer ( ) [ ranges[ 0 ] . to_owned ( ) . unwrap ( ) ] ;
738
- let switch = & datarow. buffer ( ) [ ranges[ 1 ] . to_owned ( ) . unwrap ( ) ] ;
739
- Some ( ReplicationResponse {
740
- next_tli : from_utf8 ( timeline) . unwrap ( ) . parse :: < u64 > ( ) . unwrap ( ) ,
741
- next_tli_startpos : Lsn :: from ( from_utf8 ( switch) . unwrap ( ) ) ,
742
- } )
766
+ }
743
767
}
744
- Message :: CommandComplete ( _) => None ,
745
- m => return Err ( Error :: unexpected_message ( m) ) ,
746
- } ;
768
+ }
747
769
748
- match responses. next ( ) . await ? {
770
+ match this. responses . next ( ) . await ? {
771
+ Message :: RowDescription ( rowdesc) => {
772
+ * this. replication_response =
773
+ Some ( recv_replication_response ( this. responses , rowdesc) . await ?) ;
774
+ }
749
775
Message :: CommandComplete ( _) => ( ) ,
750
776
m => return Err ( Error :: unexpected_message ( m) ) ,
751
- } ;
752
- match responses. next ( ) . await ? {
753
- Message :: ReadyForQuery ( _) => ( ) ,
754
- m => return Err ( Error :: unexpected_message ( m) ) ,
755
- } ;
777
+ }
756
778
757
- Ok ( response )
779
+ Ok ( this . replication_response . clone ( ) )
758
780
}
759
781
}
760
782
@@ -763,14 +785,27 @@ impl Stream for ReplicationStream<'_> {
763
785
764
786
fn poll_next ( self : Pin < & mut Self > , cx : & mut Context < ' _ > ) -> Poll < Option < Self :: Item > > {
765
787
let this = self . project ( ) ;
766
- let responses = this. responses ;
767
788
768
- match ready ! ( responses. poll_next( cx) ?) {
789
+ // if we already got a replication response tuple, we're done
790
+ if this. replication_response . is_some ( ) {
791
+ return Poll :: Ready ( None ) ;
792
+ }
793
+
794
+ // we are in copy mode
795
+ assert ! ( * this. copyboth_received) ;
796
+ assert ! ( !* this. copydone_sent) ;
797
+ assert ! ( !* this. copydone_received) ;
798
+ match ready ! ( this. responses. poll_next( cx) ?) {
769
799
Message :: CopyData ( body) => {
770
800
let r = ReplicationMessage :: parse ( & body. into_bytes ( ) ) ;
771
801
Poll :: Ready ( Some ( r. map_err ( Error :: parse) ) )
772
802
}
773
- Message :: CopyDone => Poll :: Ready ( None ) ,
803
+ Message :: CopyDone => {
804
+ * this. copydone_received = true ;
805
+ this. rclient . send_copydone ( ) ?;
806
+ * this. copydone_sent = true ;
807
+ Poll :: Ready ( None )
808
+ }
774
809
m => Poll :: Ready ( Some ( Err ( Error :: unexpected_message ( m) ) ) ) ,
775
810
}
776
811
}
@@ -780,6 +815,39 @@ impl Stream for ReplicationStream<'_> {
780
815
impl PinnedDrop for ReplicationStream < ' _ > {
781
816
fn drop ( mut self : Pin < & mut Self > ) {
782
817
let this = self . project ( ) ;
783
- this. rclient . send_copydone ( ) . unwrap ( ) ;
818
+ if * this. copyboth_received && !* this. copydone_sent {
819
+ this. rclient . send_copydone ( ) . unwrap ( ) ;
820
+ * this. copydone_sent = true ;
821
+ }
822
+ }
823
+ }
824
+
825
+ // Read a replication response tuple from the server. This function
826
+ // assumes that the caller has already consumed the RowDescription
827
+ // from the stream.
828
+ async fn recv_replication_response (
829
+ responses : & mut Responses ,
830
+ rowdesc : RowDescriptionBody ,
831
+ ) -> Result < ReplicationResponse , Error > {
832
+ let fields = rowdesc. fields ( ) . collect :: < Vec < _ > > ( ) . map_err ( Error :: parse) ?;
833
+ assert_eq ! ( fields. len( ) , 2 ) ;
834
+ assert_eq ! ( fields[ 0 ] . type_oid( ) , Type :: INT8 . oid( ) ) ;
835
+ assert_eq ! ( fields[ 0 ] . format( ) , 0 ) ;
836
+ assert_eq ! ( fields[ 1 ] . type_oid( ) , Type :: TEXT . oid( ) ) ;
837
+ assert_eq ! ( fields[ 1 ] . format( ) , 0 ) ;
838
+
839
+ match responses. next ( ) . await ? {
840
+ Message :: DataRow ( datarow) => {
841
+ let ranges = datarow. ranges ( ) . collect :: < Vec < _ > > ( ) . map_err ( Error :: parse) ?;
842
+ assert_eq ! ( ranges. len( ) , 2 ) ;
843
+
844
+ let timeline = & datarow. buffer ( ) [ ranges[ 0 ] . to_owned ( ) . unwrap ( ) ] ;
845
+ let switch = & datarow. buffer ( ) [ ranges[ 1 ] . to_owned ( ) . unwrap ( ) ] ;
846
+ Ok ( ReplicationResponse {
847
+ next_tli : from_utf8 ( timeline) . unwrap ( ) . parse :: < u64 > ( ) . unwrap ( ) ,
848
+ next_tli_startpos : Lsn :: from ( from_utf8 ( switch) . unwrap ( ) ) ,
849
+ } )
850
+ }
851
+ m => Err ( Error :: unexpected_message ( m) ) ,
784
852
}
785
853
}
0 commit comments