From 7cd61decb7482f11df88cd8d331775256b986a9f Mon Sep 17 00:00:00 2001 From: Stas Kelvich Date: Fri, 12 May 2023 21:46:06 +0300 Subject: [PATCH 1/6] Add query_raw_txt client method It takes all the extended protocol params as text and passes them to postgres to sort out types. With that we can avoid situations when postgres derived different type compared to what was passed in arguments. There is also propare_typed method, but since we receive data in text format anyway it makes more sense to avoid dealing with types in params. This way we also can save on roundtrip and send Parse+Bind+Describe+Execute right away without waiting for params description before Bind. Also use text protocol for responses -- that allows to grab postgres-provided serializations for types. --- postgres-types/src/lib.rs | 18 ++++++- tokio-postgres/src/client.rs | 85 ++++++++++++++++++++++++++++++- tokio-postgres/src/prepare.rs | 2 +- tokio-postgres/src/query.rs | 11 ++++ tokio-postgres/src/row.rs | 17 +++++++ tokio-postgres/src/statement.rs | 23 +++++++++ tokio-postgres/tests/test/main.rs | 28 ++++++++++ 7 files changed, 181 insertions(+), 3 deletions(-) diff --git a/postgres-types/src/lib.rs b/postgres-types/src/lib.rs index fa49d99eb..f4caa892f 100644 --- a/postgres-types/src/lib.rs +++ b/postgres-types/src/lib.rs @@ -395,6 +395,22 @@ impl WrongType { } } +/// An error indicating that a as_text conversion was attempted on a binary +/// result. +#[derive(Debug)] +pub struct WrongFormat {} + +impl Error for WrongFormat {} + +impl fmt::Display for WrongFormat { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + write!( + fmt, + "cannot read column as text while it is in binary format" + ) + } +} + /// A trait for types that can be created from a Postgres value. /// /// # Types @@ -846,7 +862,7 @@ pub trait ToSql: fmt::Debug { /// Supported Postgres message format types /// /// Using Text format in a message assumes a Postgres `SERVER_ENCODING` of `UTF8` -#[derive(Clone, Copy, Debug)] +#[derive(Clone, Copy, Debug, PartialEq)] pub enum Format { /// Text format (UTF-8) Text, diff --git a/tokio-postgres/src/client.rs b/tokio-postgres/src/client.rs index eea779f77..37cdd6827 100644 --- a/tokio-postgres/src/client.rs +++ b/tokio-postgres/src/client.rs @@ -7,8 +7,10 @@ use crate::copy_both::CopyBothDuplex; use crate::copy_out::CopyOutStream; #[cfg(feature = "runtime")] use crate::keepalive::KeepaliveConfig; +use crate::prepare::get_type; use crate::query::RowStream; use crate::simple_query::SimpleQueryStream; +use crate::statement::Column; #[cfg(feature = "runtime")] use crate::tls::MakeTlsConnect; use crate::tls::TlsConnect; @@ -20,7 +22,7 @@ use crate::{ CopyInSink, Error, Row, SimpleQueryMessage, Statement, ToStatement, Transaction, TransactionBuilder, }; -use bytes::{Buf, BytesMut}; +use bytes::{Buf, BufMut, BytesMut}; use fallible_iterator::FallibleIterator; use futures_channel::mpsc; use futures_util::{future, pin_mut, ready, StreamExt, TryStreamExt}; @@ -374,6 +376,87 @@ impl Client { query::query(&self.inner, statement, params).await } + /// Pass text directly to the Postgres backend to allow it to sort out typing itself and + /// to save a roundtrip + pub async fn query_raw_txt<'a, S, I>(&self, query: S, params: I) -> Result + where + S: AsRef, + I: IntoIterator, + I::IntoIter: ExactSizeIterator, + { + let params = params.into_iter(); + let params_len = params.len(); + + let buf = self.inner.with_buf(|buf| { + // Parse, anonymous portal + frontend::parse("", query.as_ref(), std::iter::empty(), buf).map_err(Error::encode)?; + // Bind, pass params as text, retrieve as binary + match frontend::bind( + "", // empty string selects the unnamed portal + "", // empty string selects the unnamed prepared statement + std::iter::empty(), // all parameters use the default format (text) + params, + |param, buf| { + buf.put_slice(param.as_ref().as_bytes()); + Ok(postgres_protocol::IsNull::No) + }, + Some(0), // all text + buf, + ) { + Ok(()) => Ok(()), + Err(frontend::BindError::Conversion(e)) => Err(Error::to_sql(e, 0)), + Err(frontend::BindError::Serialization(e)) => Err(Error::encode(e)), + }?; + + // Describe portal to typecast results + frontend::describe(b'P', "", buf).map_err(Error::encode)?; + // Execute + frontend::execute("", 0, buf).map_err(Error::encode)?; + // Sync + frontend::sync(buf); + + Ok(buf.split().freeze()) + })?; + + let mut responses = self + .inner + .send(RequestMessages::Single(FrontendMessage::Raw(buf)))?; + + // now read the responses + + match responses.next().await? { + Message::ParseComplete => {} + _ => return Err(Error::unexpected_message()), + } + match responses.next().await? { + Message::BindComplete => {} + _ => return Err(Error::unexpected_message()), + } + let row_description = match responses.next().await? { + Message::RowDescription(body) => Some(body), + Message::NoData => None, + _ => return Err(Error::unexpected_message()), + }; + + // construct statement object + + let parameters = vec![Type::UNKNOWN; params_len]; + + let mut columns = vec![]; + if let Some(row_description) = row_description { + let mut it = row_description.fields(); + while let Some(field) = it.next().map_err(Error::parse)? { + let type_ = get_type(&self.inner, field.type_oid()).await?; + let column = Column::new(field.name().to_string(), type_); + columns.push(column); + } + } + + let statement = Statement::new_text(&self.inner, "".to_owned(), parameters, columns); + + Ok(RowStream::new(statement, responses)) + } + /// Executes a statement, returning the number of rows modified. /// /// A statement may contain parameters, specified by `$n`, where `n` is the index of the parameter of the list diff --git a/tokio-postgres/src/prepare.rs b/tokio-postgres/src/prepare.rs index e3f09a7c2..ba8d5a43e 100644 --- a/tokio-postgres/src/prepare.rs +++ b/tokio-postgres/src/prepare.rs @@ -126,7 +126,7 @@ fn encode(client: &InnerClient, name: &str, query: &str, types: &[Type]) -> Resu }) } -async fn get_type(client: &Arc, oid: Oid) -> Result { +pub async fn get_type(client: &Arc, oid: Oid) -> Result { if let Some(type_) = Type::from_oid(oid) { return Ok(type_); } diff --git a/tokio-postgres/src/query.rs b/tokio-postgres/src/query.rs index 71db8769a..fa16df9e2 100644 --- a/tokio-postgres/src/query.rs +++ b/tokio-postgres/src/query.rs @@ -207,6 +207,17 @@ pin_project! { } } +impl RowStream { + /// Creates a new `RowStream`. + pub fn new(statement: Statement, responses: Responses) -> Self { + RowStream { + statement, + responses, + _p: PhantomPinned, + } + } +} + impl Stream for RowStream { type Item = Result; diff --git a/tokio-postgres/src/row.rs b/tokio-postgres/src/row.rs index db179b432..d5698f806 100644 --- a/tokio-postgres/src/row.rs +++ b/tokio-postgres/src/row.rs @@ -7,6 +7,7 @@ use crate::types::{FromSql, Type, WrongType}; use crate::{Error, Statement}; use fallible_iterator::FallibleIterator; use postgres_protocol::message::backend::DataRowBody; +use postgres_types::{Format, WrongFormat}; use std::fmt; use std::ops::Range; use std::str; @@ -187,6 +188,22 @@ impl Row { let range = self.ranges[idx].to_owned()?; Some(&self.body.buffer()[range]) } + + /// Interpret the column at the given index as text + /// + /// Useful when using query_raw_txt() which sets text transfer mode + pub fn as_text(&self, idx: usize) -> Result, Error> { + if self.statement.output_format() == Format::Text { + match self.col_buffer(idx) { + Some(raw) => { + FromSql::from_sql(&Type::TEXT, raw).map_err(|e| Error::from_sql(e, idx)) + } + None => Ok(None), + } + } else { + Err(Error::from_sql(Box::new(WrongFormat {}), idx)) + } + } } impl AsName for SimpleColumn { diff --git a/tokio-postgres/src/statement.rs b/tokio-postgres/src/statement.rs index 97561a8e4..b7ab11866 100644 --- a/tokio-postgres/src/statement.rs +++ b/tokio-postgres/src/statement.rs @@ -3,6 +3,7 @@ use crate::codec::FrontendMessage; use crate::connection::RequestMessages; use crate::types::Type; use postgres_protocol::message::frontend; +use postgres_types::Format; use std::{ fmt, sync::{Arc, Weak}, @@ -13,6 +14,7 @@ struct StatementInner { name: String, params: Vec, columns: Vec, + output_format: Format, } impl Drop for StatementInner { @@ -46,6 +48,22 @@ impl Statement { name, params, columns, + output_format: Format::Binary, + })) + } + + pub(crate) fn new_text( + inner: &Arc, + name: String, + params: Vec, + columns: Vec, + ) -> Statement { + Statement(Arc::new(StatementInner { + client: Arc::downgrade(inner), + name, + params, + columns, + output_format: Format::Text, })) } @@ -62,6 +80,11 @@ impl Statement { pub fn columns(&self) -> &[Column] { &self.0.columns } + + /// Returns output format for the statement. + pub fn output_format(&self) -> Format { + self.0.output_format + } } /// Information about a column of a query. diff --git a/tokio-postgres/tests/test/main.rs b/tokio-postgres/tests/test/main.rs index 8de2b75a2..64da95f11 100644 --- a/tokio-postgres/tests/test/main.rs +++ b/tokio-postgres/tests/test/main.rs @@ -251,6 +251,34 @@ async fn custom_array() { } } +#[tokio::test] +async fn query_raw_txt() { + let client = connect("user=postgres").await; + + let rows: Vec = client + .query_raw_txt("SELECT 55 * $1", ["42"]) + .await + .unwrap() + .try_collect() + .await + .unwrap(); + + assert_eq!(rows.len(), 1); + let res: i32 = rows[0].as_text(0).unwrap().parse::().unwrap(); + assert_eq!(res, 55 * 42); + + let rows: Vec = client + .query_raw_txt("SELECT $1", ["42"]) + .await + .unwrap() + .try_collect() + .await + .unwrap(); + + assert_eq!(rows.len(), 1); + assert_eq!(rows[0].get::<_, &str>(0), "42"); +} + #[tokio::test] async fn custom_composite() { let client = connect("user=postgres").await; From b114de337208b74cae157f9b850ad4c676f220a2 Mon Sep 17 00:00:00 2001 From: Stas Kelvich Date: Fri, 19 May 2023 02:07:48 +0300 Subject: [PATCH 2/6] Catch command tag --- tokio-postgres/src/query.rs | 22 +++++++++++++++++++--- tokio-postgres/tests/test/main.rs | 21 ++++++++++++++++++++- 2 files changed, 39 insertions(+), 4 deletions(-) diff --git a/tokio-postgres/src/query.rs b/tokio-postgres/src/query.rs index fa16df9e2..a486b4f88 100644 --- a/tokio-postgres/src/query.rs +++ b/tokio-postgres/src/query.rs @@ -52,6 +52,7 @@ where Ok(RowStream { statement, responses, + command_tag: None, _p: PhantomPinned, }) } @@ -72,6 +73,7 @@ pub async fn query_portal( Ok(RowStream { statement: portal.statement().clone(), responses, + command_tag: None, _p: PhantomPinned, }) } @@ -202,6 +204,7 @@ pin_project! { pub struct RowStream { statement: Statement, responses: Responses, + command_tag: Option, #[pin] _p: PhantomPinned, } @@ -213,6 +216,7 @@ impl RowStream { RowStream { statement, responses, + command_tag: None, _p: PhantomPinned, } } @@ -228,12 +232,24 @@ impl Stream for RowStream { Message::DataRow(body) => { return Poll::Ready(Some(Ok(Row::new(this.statement.clone(), body)?))) } - Message::EmptyQueryResponse - | Message::CommandComplete(_) - | Message::PortalSuspended => {} + Message::EmptyQueryResponse | Message::PortalSuspended => {} + Message::CommandComplete(body) => { + if let Ok(tag) = body.tag() { + *this.command_tag = Some(tag.to_string()); + } + } Message::ReadyForQuery(_) => return Poll::Ready(None), _ => return Poll::Ready(Some(Err(Error::unexpected_message()))), } } } } + +impl RowStream { + /// Returns the command tag of this query. + /// + /// This is only available after the stream has been exhausted. + pub fn command_tag(&self) -> Option { + self.command_tag.clone() + } +} diff --git a/tokio-postgres/tests/test/main.rs b/tokio-postgres/tests/test/main.rs index 64da95f11..aa0eb1652 100644 --- a/tokio-postgres/tests/test/main.rs +++ b/tokio-postgres/tests/test/main.rs @@ -264,7 +264,7 @@ async fn query_raw_txt() { .unwrap(); assert_eq!(rows.len(), 1); - let res: i32 = rows[0].as_text(0).unwrap().parse::().unwrap(); + let res: i32 = rows[0].as_text(0).unwrap().unwrap().parse::().unwrap(); assert_eq!(res, 55 * 42); let rows: Vec = client @@ -279,6 +279,25 @@ async fn query_raw_txt() { assert_eq!(rows[0].get::<_, &str>(0), "42"); } +#[tokio::test] +async fn command_tag() { + let client = connect("user=postgres").await; + + let row_stream = client + .query_raw_txt("select unnest('{1,2,3}'::int[]);", []) + .await + .unwrap(); + + pin_mut!(row_stream); + + let mut rows: Vec = Vec::new(); + while let Some(row) = row_stream.next().await { + rows.push(row.unwrap()); + } + + assert_eq!(row_stream.command_tag(), Some("SELECT 3".to_string())); +} + #[tokio::test] async fn custom_composite() { let client = connect("user=postgres").await; From f898c2e91d3f7cc849c28fd1db768595ddca2ee6 Mon Sep 17 00:00:00 2001 From: Stas Kelvich Date: Fri, 19 May 2023 02:12:54 +0300 Subject: [PATCH 3/6] Bump rust version in CI and fix clippy warnings --- .github/workflows/ci.yml | 2 +- postgres-derive-test/src/lib.rs | 4 ++-- postgres-protocol/src/authentication/sasl.rs | 2 +- postgres-protocol/src/message/backend.rs | 2 +- 4 files changed, 5 insertions(+), 5 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 0e56ca84d..1ca030d26 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -57,7 +57,7 @@ jobs: - run: docker compose up -d - uses: sfackler/actions/rustup@master with: - version: 1.62.0 + version: 1.65.0 - run: echo "::set-output name=version::$(rustc --version)" id: rust-version - uses: actions/cache@v1 diff --git a/postgres-derive-test/src/lib.rs b/postgres-derive-test/src/lib.rs index d1478ac4c..f0534f32c 100644 --- a/postgres-derive-test/src/lib.rs +++ b/postgres-derive-test/src/lib.rs @@ -14,7 +14,7 @@ where T: PartialEq + FromSqlOwned + ToSql + Sync, S: fmt::Display, { - for &(ref val, ref repr) in checks.iter() { + for (val, repr) in checks.iter() { let stmt = conn .prepare(&format!("SELECT {}::{}", *repr, sql_type)) .unwrap(); @@ -38,7 +38,7 @@ pub fn test_type_asymmetric( S: fmt::Display, C: Fn(&T, &F) -> bool, { - for &(ref val, ref repr) in checks.iter() { + for (val, repr) in checks.iter() { let stmt = conn .prepare(&format!("SELECT {}::{}", *repr, sql_type)) .unwrap(); diff --git a/postgres-protocol/src/authentication/sasl.rs b/postgres-protocol/src/authentication/sasl.rs index fdb88114a..41d0e41b0 100644 --- a/postgres-protocol/src/authentication/sasl.rs +++ b/postgres-protocol/src/authentication/sasl.rs @@ -389,7 +389,7 @@ impl<'a> Parser<'a> { } fn posit_number(&mut self) -> io::Result { - let n = self.take_while(|c| matches!(c, '0'..='9'))?; + let n = self.take_while(|c| c.is_ascii_digit())?; n.parse() .map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e)) } diff --git a/postgres-protocol/src/message/backend.rs b/postgres-protocol/src/message/backend.rs index 9aa46588e..b6883cc3c 100644 --- a/postgres-protocol/src/message/backend.rs +++ b/postgres-protocol/src/message/backend.rs @@ -707,7 +707,7 @@ impl<'a> FallibleIterator for DataRowRanges<'a> { )); } let base = self.len - self.buf.len(); - self.buf = &self.buf[len as usize..]; + self.buf = &self.buf[len..]; Ok(Some(Some(base..base + len))) } } From 5bc7d0e482bb5c4ee59807ff4a657954a95181a0 Mon Sep 17 00:00:00 2001 From: Stas Kelvich Date: Fri, 19 May 2023 21:12:16 +0300 Subject: [PATCH 4/6] expose row buffer size --- tokio-postgres/src/row.rs | 5 +++++ tokio-postgres/tests/test/main.rs | 1 + 2 files changed, 6 insertions(+) diff --git a/tokio-postgres/src/row.rs b/tokio-postgres/src/row.rs index d5698f806..ce4efed7e 100644 --- a/tokio-postgres/src/row.rs +++ b/tokio-postgres/src/row.rs @@ -204,6 +204,11 @@ impl Row { Err(Error::from_sql(Box::new(WrongFormat {}), idx)) } } + + /// Row byte size + pub fn body_len(&self) -> usize { + self.body.buffer().len() + } } impl AsName for SimpleColumn { diff --git a/tokio-postgres/tests/test/main.rs b/tokio-postgres/tests/test/main.rs index aa0eb1652..3213a6dad 100644 --- a/tokio-postgres/tests/test/main.rs +++ b/tokio-postgres/tests/test/main.rs @@ -277,6 +277,7 @@ async fn query_raw_txt() { assert_eq!(rows.len(), 1); assert_eq!(rows[0].get::<_, &str>(0), "42"); + assert!(rows[0].body_len() > 0); } #[tokio::test] From 4f780bf34e3168ee352dcd3b18090abb6906a67a Mon Sep 17 00:00:00 2001 From: Arthur Petukhovsky Date: Mon, 22 May 2023 16:54:01 +0000 Subject: [PATCH 5/6] Limit backend messages length to prevent DoS --- tokio-postgres/src/codec.rs | 13 ++++++++++++- tokio-postgres/src/config.rs | 13 +++++++++++++ tokio-postgres/src/connect_raw.rs | 4 +++- 3 files changed, 28 insertions(+), 2 deletions(-) diff --git a/tokio-postgres/src/codec.rs b/tokio-postgres/src/codec.rs index 9d078044b..23c371542 100644 --- a/tokio-postgres/src/codec.rs +++ b/tokio-postgres/src/codec.rs @@ -35,7 +35,9 @@ impl FallibleIterator for BackendMessages { } } -pub struct PostgresCodec; +pub struct PostgresCodec { + pub max_message_size: Option, +} impl Encoder for PostgresCodec { type Error = io::Error; @@ -64,6 +66,15 @@ impl Decoder for PostgresCodec { break; } + if let Some(max) = self.max_message_size { + if len > max { + return Err(io::Error::new( + io::ErrorKind::InvalidInput, + "message too large", + )); + } + } + match header.tag() { backend::NOTICE_RESPONSE_TAG | backend::NOTIFICATION_RESPONSE_TAG diff --git a/tokio-postgres/src/config.rs b/tokio-postgres/src/config.rs index 4153fa250..3783e40b2 100644 --- a/tokio-postgres/src/config.rs +++ b/tokio-postgres/src/config.rs @@ -185,6 +185,7 @@ pub struct Config { pub(crate) target_session_attrs: TargetSessionAttrs, pub(crate) channel_binding: ChannelBinding, pub(crate) replication_mode: Option, + pub(crate) max_backend_message_size: Option, } impl Default for Config { @@ -217,6 +218,7 @@ impl Config { target_session_attrs: TargetSessionAttrs::Any, channel_binding: ChannelBinding::Prefer, replication_mode: None, + max_backend_message_size: None, } } @@ -472,6 +474,17 @@ impl Config { self.replication_mode } + /// Set limit for backend messages size. + pub fn max_backend_message_size(&mut self, max_backend_message_size: usize) -> &mut Config { + self.max_backend_message_size = Some(max_backend_message_size); + self + } + + /// Get limit for backend messages size. + pub fn get_max_backend_message_size(&self) -> Option { + self.max_backend_message_size + } + fn param(&mut self, key: &str, value: &str) -> Result<(), Error> { match key { "user" => { diff --git a/tokio-postgres/src/connect_raw.rs b/tokio-postgres/src/connect_raw.rs index ddfca2894..782b14959 100644 --- a/tokio-postgres/src/connect_raw.rs +++ b/tokio-postgres/src/connect_raw.rs @@ -90,7 +90,9 @@ where let stream = connect_tls(stream, config.ssl_mode, tls).await?; let mut stream = StartupStream { - inner: Framed::new(stream, PostgresCodec), + inner: Framed::new(stream, PostgresCodec { + max_message_size: config.max_backend_message_size, + }), buf: BackendMessages::empty(), delayed: VecDeque::new(), }; From 2c1ad15da2b4cd011b5f08f1b34bbe2e16d89236 Mon Sep 17 00:00:00 2001 From: Arthur Petukhovsky Date: Mon, 22 May 2023 17:42:51 +0000 Subject: [PATCH 6/6] Add test for max_backend_message_size --- tokio-postgres/src/config.rs | 8 ++++++++ tokio-postgres/src/connect_raw.rs | 9 ++++++--- tokio-postgres/tests/test/main.rs | 24 ++++++++++++++++++++++++ 3 files changed, 38 insertions(+), 3 deletions(-) diff --git a/tokio-postgres/src/config.rs b/tokio-postgres/src/config.rs index 3783e40b2..fdb5e6359 100644 --- a/tokio-postgres/src/config.rs +++ b/tokio-postgres/src/config.rs @@ -599,6 +599,14 @@ impl Config { self.replication_mode(mode); } } + "max_backend_message_size" => { + let limit = value.parse::().map_err(|_| { + Error::config_parse(Box::new(InvalidValue("max_backend_message_size"))) + })?; + if limit > 0 { + self.max_backend_message_size(limit); + } + } key => { return Err(Error::config_parse(Box::new(UnknownOption( key.to_string(), diff --git a/tokio-postgres/src/connect_raw.rs b/tokio-postgres/src/connect_raw.rs index 782b14959..0beead11f 100644 --- a/tokio-postgres/src/connect_raw.rs +++ b/tokio-postgres/src/connect_raw.rs @@ -90,9 +90,12 @@ where let stream = connect_tls(stream, config.ssl_mode, tls).await?; let mut stream = StartupStream { - inner: Framed::new(stream, PostgresCodec { - max_message_size: config.max_backend_message_size, - }), + inner: Framed::new( + stream, + PostgresCodec { + max_message_size: config.max_backend_message_size, + }, + ), buf: BackendMessages::empty(), delayed: VecDeque::new(), }; diff --git a/tokio-postgres/tests/test/main.rs b/tokio-postgres/tests/test/main.rs index 3213a6dad..551f6ec5c 100644 --- a/tokio-postgres/tests/test/main.rs +++ b/tokio-postgres/tests/test/main.rs @@ -280,6 +280,30 @@ async fn query_raw_txt() { assert!(rows[0].body_len() > 0); } +#[tokio::test] +async fn limit_max_backend_message_size() { + let client = connect("user=postgres max_backend_message_size=10000").await; + let small: Vec = client + .query_raw_txt("SELECT REPEAT('a', 20)", []) + .await + .unwrap() + .try_collect() + .await + .unwrap(); + + assert_eq!(small.len(), 1); + assert_eq!(small[0].as_text(0).unwrap().unwrap().len(), 20); + + let large: Result, Error> = client + .query_raw_txt("SELECT REPEAT('a', 2000000)", []) + .await + .unwrap() + .try_collect() + .await; + + assert!(large.is_err()); +} + #[tokio::test] async fn command_tag() { let client = connect("user=postgres").await;