diff --git a/sqld/src/connection/libsql.rs b/sqld/src/connection/libsql.rs index 1a492602..40d458da 100644 --- a/sqld/src/connection/libsql.rs +++ b/sqld/src/connection/libsql.rs @@ -531,7 +531,7 @@ impl Connection { let blocked = match query.stmt.kind { StmtKind::Read | StmtKind::TxnBegin | StmtKind::Other => config.block_reads, StmtKind::Write => config.block_reads || config.block_writes, - StmtKind::TxnEnd => false, + StmtKind::TxnEnd | StmtKind::Release | StmtKind::Savepoint => false, }; if blocked { return Err(Error::Blocked(config.block_reason.clone())); diff --git a/sqld/src/http/user/mod.rs b/sqld/src/http/user/mod.rs index 367f7e77..6e42df56 100644 --- a/sqld/src/http/user/mod.rs +++ b/sqld/src/http/user/mod.rs @@ -97,6 +97,17 @@ fn parse_queries(queries: Vec) -> crate::Result> { out.push(query); } + // It's too complicated to predict the state of a transaction with savepoints in legacy http, + // forbid them instead. + if out + .iter() + .any(|q| q.stmt.kind.is_release() || q.stmt.kind.is_release()) + { + return Err(Error::QueryError( + "savepoints are not supported in HTTP API, use hrana protocol instead".to_string(), + )); + } + match predict_final_state(State::Init, out.iter().map(|q| &q.stmt)) { State::Txn => { return Err(Error::QueryError( diff --git a/sqld/src/query_analysis.rs b/sqld/src/query_analysis.rs index 5f0e4f37..a592376f 100644 --- a/sqld/src/query_analysis.rs +++ b/sqld/src/query_analysis.rs @@ -28,6 +28,8 @@ pub enum StmtKind { TxnEnd, Read, Write, + Savepoint, + Release, Other, } @@ -51,7 +53,13 @@ impl StmtKind { Cmd::Explain(_) => Some(Self::Other), Cmd::ExplainQueryPlan(_) => Some(Self::Other), Cmd::Stmt(Stmt::Begin { .. }) => Some(Self::TxnBegin), - Cmd::Stmt(Stmt::Commit { .. } | Stmt::Rollback { .. }) => Some(Self::TxnEnd), + Cmd::Stmt( + Stmt::Commit { .. } + | Stmt::Rollback { + savepoint_name: None, + .. + }, + ) => Some(Self::TxnEnd), Cmd::Stmt( Stmt::CreateVirtualTable { tbl_name, .. } | Stmt::CreateTable { @@ -99,6 +107,12 @@ impl StmtKind { temporary: false, .. }) => Some(Self::Write), Cmd::Stmt(Stmt::DropView { .. }) => Some(Self::Write), + Cmd::Stmt(Stmt::Savepoint(_)) => Some(Self::Savepoint), + Cmd::Stmt(Stmt::Release(_)) + | Cmd::Stmt(Stmt::Rollback { + savepoint_name: Some(_), + .. + }) => Some(Self::Release), _ => None, } } @@ -167,6 +181,22 @@ impl StmtKind { }, } } + + /// Returns `true` if the stmt kind is [`Savepoint`]. + /// + /// [`Savepoint`]: StmtKind::Savepoint + #[must_use] + pub fn is_savepoint(&self) -> bool { + matches!(self, Self::Savepoint) + } + + /// Returns `true` if the stmt kind is [`Release`]. + /// + /// [`Release`]: StmtKind::Release + #[must_use] + pub fn is_release(&self) -> bool { + matches!(self, Self::Release) + } } /// The state of a transaction for a series of statement @@ -188,6 +218,7 @@ impl State { (state, StmtKind::Other | StmtKind::Write | StmtKind::Read) => state, (State::Invalid, _) => State::Invalid, (State::Init, StmtKind::TxnBegin) => State::Txn, + _ => State::Invalid, }; } @@ -276,7 +307,11 @@ impl Statement { pub fn is_read_only(&self) -> bool { matches!( self.kind, - StmtKind::Read | StmtKind::TxnEnd | StmtKind::TxnBegin + StmtKind::Read + | StmtKind::TxnEnd + | StmtKind::TxnBegin + | StmtKind::Release + | StmtKind::Savepoint ) } }