Skip to content

Commit 6683d1f

Browse files
committed
handle stop_replication with optional tuple
1 parent fbdddf7 commit 6683d1f

File tree

2 files changed

+134
-66
lines changed

2 files changed

+134
-66
lines changed

tokio-postgres/src/replication_client.rs

Lines changed: 133 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,7 @@ use fallible_iterator::FallibleIterator;
129129
use futures::{ready, Stream};
130130
use pin_project::{pin_project, pinned_drop};
131131
use postgres_protocol::escape::{escape_identifier, escape_literal};
132-
use postgres_protocol::message::backend::{Message, ReplicationMessage};
132+
use postgres_protocol::message::backend::{Message, ReplicationMessage, RowDescriptionBody};
133133
use postgres_protocol::message::frontend;
134134
use std::marker::PhantomPinned;
135135
use std::path::{Path, PathBuf};
@@ -251,7 +251,7 @@ impl CreateReplicationSlotResponse {
251251

252252
/// Response sent after streaming from a timeline that is not the
253253
/// current timeline.
254-
#[derive(Debug)]
254+
#[derive(Clone, Debug)]
255255
pub struct ReplicationResponse {
256256
next_tli: u64,
257257
next_tli_startpos: Lsn,
@@ -272,15 +272,11 @@ impl ReplicationResponse {
272272
/// Represents a client connected in replication mode.
273273
pub struct ReplicationClient {
274274
client: Client,
275-
replication_stream_active: bool,
276275
}
277276

278277
impl ReplicationClient {
279278
pub(crate) fn new(client: Client) -> ReplicationClient {
280-
ReplicationClient {
281-
client: client,
282-
replication_stream_active: false,
283-
}
279+
ReplicationClient { client: client }
284280
}
285281
}
286282

@@ -657,30 +653,51 @@ impl ReplicationClient {
657653
&'a mut self,
658654
command: String,
659655
) -> Result<Pin<Box<ReplicationStream<'a>>>, Error> {
656+
let mut copyboth_received = false;
657+
let mut replication_response: Option<ReplicationResponse> = None;
660658
let mut responses = self.send(&command).await?;
661-
self.replication_stream_active = true;
662659

660+
// Before we construct the ReplicationStream, we must know
661+
// whether the server entered copy mode or not. Otherwise, if
662+
// the ReplicationStream were to be dropped, we wouldn't know
663+
// whether to send a CopyDone message or not (and it would be
664+
// bad to try to receive and process the responses during the
665+
// destructor).
666+
667+
// If the timeline selected is the current one, the server
668+
// will always enter copy mode. If the timeline is historic,
669+
// and if there is no work to do, the server will skip copy
670+
// mode and immediately send a response tuple.
663671
match responses.next().await? {
664-
Message::CopyBothResponse(_) => {}
672+
Message::CopyBothResponse(_) => {
673+
copyboth_received = true;
674+
}
675+
Message::RowDescription(rowdesc) => {
676+
// Never entered copy mode, so don't bother returning
677+
// a stream, just process the response.
678+
replication_response =
679+
Some(recv_replication_response(&mut responses, rowdesc).await?);
680+
}
665681
m => return Err(Error::unexpected_message(m)),
666682
}
667683

668684
Ok(Box::pin(ReplicationStream {
669685
rclient: self,
670686
responses: responses,
687+
copyboth_received: copyboth_received,
688+
copydone_sent: false,
689+
copydone_received: false,
690+
replication_response: replication_response,
671691
_phantom_pinned: PhantomPinned,
672692
}))
673693
}
674694

675695
fn send_copydone(&mut self) -> Result<(), Error> {
676-
if self.replication_stream_active {
677-
let iclient = self.client.inner();
678-
let mut buf = BytesMut::new();
679-
frontend::copy_done(&mut buf);
680-
iclient
681-
.unpipelined_send(RequestMessages::Single(FrontendMessage::Raw(buf.freeze())))?;
682-
self.replication_stream_active = false;
683-
}
696+
let iclient = self.client.inner();
697+
let mut buf = BytesMut::new();
698+
frontend::copy_done(&mut buf);
699+
iclient.unpipelined_send(RequestMessages::Single(FrontendMessage::Raw(buf.freeze())))?;
700+
684701
Ok(())
685702
}
686703
}
@@ -690,71 +707,76 @@ impl ReplicationClient {
690707
/// [CopyData](postgres_protocol::message::backend::Message::CopyData).
691708
///
692709
/// Intended to be used with the [next()](tokio::stream::StreamExt::next) method.
710+
///
711+
/// If the timeline specified with
712+
/// [start_physical_replication()](ReplicationClient::start_physical_replication)
713+
/// or
714+
/// [start_logical_replication()](ReplicationClient::start_logical_replication())
715+
/// is the current timeline, the stream is indefinite, and must be
716+
/// terminated with
717+
/// [stop_replication()](ReplicationStream::stop_replication()) (which
718+
/// will not return a response tuple); or by dropping the
719+
/// [ReplicationStream](ReplicationStream).
720+
///
721+
/// If the timeline is not the current timeline, the stream will
722+
/// terminate when the end of the timeline is reached, and
723+
/// [stop_replication()](ReplicationStream::stop_replication()) will
724+
/// return a response tuple.
693725
#[pin_project(PinnedDrop)]
694726
pub struct ReplicationStream<'a> {
695727
rclient: &'a mut ReplicationClient,
696728
responses: Responses,
729+
copyboth_received: bool,
730+
copydone_sent: bool,
731+
copydone_received: bool,
732+
replication_response: Option<ReplicationResponse>,
697733
#[pin]
698734
_phantom_pinned: PhantomPinned,
699735
}
700736

701737
impl ReplicationStream<'_> {
702738
/// Stop replication stream and return the replication client object.
703-
pub async fn stop_replication(mut self: Pin<Box<Self>>) -> Result<Option<ReplicationResponse>, Error> {
739+
pub async fn stop_replication(
740+
mut self: Pin<Box<Self>>,
741+
) -> Result<Option<ReplicationResponse>, Error> {
704742
let this = self.as_mut().project();
705743

706-
this.rclient.send_copydone()?;
707-
let responses = this.responses;
708-
709-
// drain remaining CopyData messages and CopyDone
710-
loop {
711-
match responses.next().await? {
712-
Message::CopyData(_) => (),
713-
Message::CopyDone => break,
714-
m => return Err(Error::unexpected_message(m)),
715-
}
744+
if this.replication_response.is_some() {
745+
return Ok(this.replication_response.clone());
716746
}
717747

718-
let next_message = responses.next().await?;
748+
// we must be in copy mode; shut it down
749+
assert!(*this.copyboth_received);
750+
if !*this.copydone_sent {
751+
this.rclient.send_copydone()?;
752+
*this.copydone_sent = true;
753+
}
719754

720-
let response = match next_message {
721-
Message::RowDescription(rowdesc) => {
722-
let datarow = match responses.next().await? {
723-
Message::DataRow(m) => m,
755+
// If server didn't already shut down copy, drain remaining
756+
// CopyData and the CopyDone.
757+
if !*this.copydone_received {
758+
loop {
759+
match this.responses.next().await? {
760+
Message::CopyData(_) => (),
761+
Message::CopyDone => {
762+
*this.copydone_received = true;
763+
break;
764+
}
724765
m => return Err(Error::unexpected_message(m)),
725-
};
726-
727-
let fields = rowdesc.fields().collect::<Vec<_>>().map_err(Error::parse)?;
728-
let ranges = datarow.ranges().collect::<Vec<_>>().map_err(Error::parse)?;
729-
730-
assert_eq!(fields.len(), 2);
731-
assert_eq!(fields[0].type_oid(), Type::INT8.oid());
732-
assert_eq!(fields[0].format(), 0);
733-
assert_eq!(fields[1].type_oid(), Type::TEXT.oid());
734-
assert_eq!(fields[1].format(), 0);
735-
assert_eq!(ranges.len(), 2);
736-
737-
let timeline = &datarow.buffer()[ranges[0].to_owned().unwrap()];
738-
let switch = &datarow.buffer()[ranges[1].to_owned().unwrap()];
739-
Some(ReplicationResponse {
740-
next_tli: from_utf8(timeline).unwrap().parse::<u64>().unwrap(),
741-
next_tli_startpos: Lsn::from(from_utf8(switch).unwrap()),
742-
})
766+
}
743767
}
744-
Message::CommandComplete(_) => None,
745-
m => return Err(Error::unexpected_message(m)),
746-
};
768+
}
747769

748-
match responses.next().await? {
770+
match this.responses.next().await? {
771+
Message::RowDescription(rowdesc) => {
772+
*this.replication_response =
773+
Some(recv_replication_response(this.responses, rowdesc).await?);
774+
}
749775
Message::CommandComplete(_) => (),
750776
m => return Err(Error::unexpected_message(m)),
751-
};
752-
match responses.next().await? {
753-
Message::ReadyForQuery(_) => (),
754-
m => return Err(Error::unexpected_message(m)),
755-
};
777+
}
756778

757-
Ok(response)
779+
Ok(this.replication_response.clone())
758780
}
759781
}
760782

@@ -763,14 +785,27 @@ impl Stream for ReplicationStream<'_> {
763785

764786
fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
765787
let this = self.project();
766-
let responses = this.responses;
767788

768-
match ready!(responses.poll_next(cx)?) {
789+
// if we already got a replication response tuple, we're done
790+
if this.replication_response.is_some() {
791+
return Poll::Ready(None);
792+
}
793+
794+
// we are in copy mode
795+
assert!(*this.copyboth_received);
796+
assert!(!*this.copydone_sent);
797+
assert!(!*this.copydone_received);
798+
match ready!(this.responses.poll_next(cx)?) {
769799
Message::CopyData(body) => {
770800
let r = ReplicationMessage::parse(&body.into_bytes());
771801
Poll::Ready(Some(r.map_err(Error::parse)))
772802
}
773-
Message::CopyDone => Poll::Ready(None),
803+
Message::CopyDone => {
804+
*this.copydone_received = true;
805+
this.rclient.send_copydone()?;
806+
*this.copydone_sent = true;
807+
Poll::Ready(None)
808+
}
774809
m => Poll::Ready(Some(Err(Error::unexpected_message(m)))),
775810
}
776811
}
@@ -780,6 +815,39 @@ impl Stream for ReplicationStream<'_> {
780815
impl PinnedDrop for ReplicationStream<'_> {
781816
fn drop(mut self: Pin<&mut Self>) {
782817
let this = self.project();
783-
this.rclient.send_copydone().unwrap();
818+
if *this.copyboth_received && !*this.copydone_sent {
819+
this.rclient.send_copydone().unwrap();
820+
*this.copydone_sent = true;
821+
}
822+
}
823+
}
824+
825+
// Read a replication response tuple from the server. This function
826+
// assumes that the caller has already consumed the RowDescription
827+
// from the stream.
828+
async fn recv_replication_response(
829+
responses: &mut Responses,
830+
rowdesc: RowDescriptionBody,
831+
) -> Result<ReplicationResponse, Error> {
832+
let fields = rowdesc.fields().collect::<Vec<_>>().map_err(Error::parse)?;
833+
assert_eq!(fields.len(), 2);
834+
assert_eq!(fields[0].type_oid(), Type::INT8.oid());
835+
assert_eq!(fields[0].format(), 0);
836+
assert_eq!(fields[1].type_oid(), Type::TEXT.oid());
837+
assert_eq!(fields[1].format(), 0);
838+
839+
match responses.next().await? {
840+
Message::DataRow(datarow) => {
841+
let ranges = datarow.ranges().collect::<Vec<_>>().map_err(Error::parse)?;
842+
assert_eq!(ranges.len(), 2);
843+
844+
let timeline = &datarow.buffer()[ranges[0].to_owned().unwrap()];
845+
let switch = &datarow.buffer()[ranges[1].to_owned().unwrap()];
846+
Ok(ReplicationResponse {
847+
next_tli: from_utf8(timeline).unwrap().parse::<u64>().unwrap(),
848+
next_tli_startpos: Lsn::from(from_utf8(switch).unwrap()),
849+
})
850+
}
851+
m => Err(Error::unexpected_message(m)),
784852
}
785853
}

tokio-postgres/tests/test/replication.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
use postgres_protocol::message::backend::ReplicationMessage;
22
use tokio::stream::StreamExt;
3-
use tokio_postgres::Client;
43
use tokio_postgres::replication_client::ReplicationClient;
4+
use tokio_postgres::Client;
55
use tokio_postgres::{connect, connect_replication, NoTls, ReplicationMode};
66

77
const LOGICAL_BEGIN_TAG: u8 = b'B';

0 commit comments

Comments
 (0)