|
| 1 | +use crate::replication::LogReadError; |
| 2 | +use crate::replication::{frame::Frame, primary::frame_stream::FrameStream, ReplicationLogger}; |
| 3 | +use anyhow::{Context, Result}; |
| 4 | +use axum::extract::State; |
| 5 | +use hyper::{Body, Response}; |
| 6 | +use std::net::SocketAddr; |
| 7 | +use std::sync::Arc; |
| 8 | + |
| 9 | +#[derive(Debug, serde::Deserialize, serde::Serialize)] |
| 10 | +pub struct FramesRequest { |
| 11 | + pub next_offset: u64, |
| 12 | +} |
| 13 | + |
| 14 | +#[derive(Debug, serde::Deserialize, serde::Serialize)] |
| 15 | +pub struct Frames { |
| 16 | + pub frames: Vec<Frame>, |
| 17 | +} |
| 18 | + |
| 19 | +#[derive(Debug, serde::Deserialize, serde::Serialize)] |
| 20 | +pub struct Hello { |
| 21 | + pub generation_id: uuid::Uuid, |
| 22 | + pub generation_start_index: u64, |
| 23 | + pub database_id: uuid::Uuid, |
| 24 | +} |
| 25 | + |
| 26 | +// Thin wrapper to allow returning anyhow errors from axum |
| 27 | +struct AppError(anyhow::Error); |
| 28 | + |
| 29 | +impl axum::response::IntoResponse for AppError { |
| 30 | + fn into_response(self) -> axum::response::Response { |
| 31 | + ( |
| 32 | + hyper::StatusCode::INTERNAL_SERVER_ERROR, |
| 33 | + format!("Replication failed: {}", self.0), |
| 34 | + ) |
| 35 | + .into_response() |
| 36 | + } |
| 37 | +} |
| 38 | + |
| 39 | +impl<E: Into<anyhow::Error>> From<E> for AppError { |
| 40 | + fn from(err: E) -> Self { |
| 41 | + Self(err.into()) |
| 42 | + } |
| 43 | +} |
| 44 | + |
| 45 | +pub async fn run(addr: SocketAddr, logger: Arc<ReplicationLogger>) -> Result<()> { |
| 46 | + use axum::routing::{get, post}; |
| 47 | + let router = axum::Router::new() |
| 48 | + .route("/hello", get(handle_hello)) |
| 49 | + .route("/frames", post(handle_frames)) |
| 50 | + .with_state(logger); |
| 51 | + |
| 52 | + let server = hyper::Server::try_bind(&addr) |
| 53 | + .context("Could not bind admin HTTP API server")? |
| 54 | + .serve(router.into_make_service()); |
| 55 | + |
| 56 | + tracing::info!( |
| 57 | + "Listening for replication HTTP API requests on {}", |
| 58 | + server.local_addr() |
| 59 | + ); |
| 60 | + server.await?; |
| 61 | + Ok(()) |
| 62 | +} |
| 63 | + |
| 64 | +impl Frames { |
| 65 | + pub fn new() -> Self { |
| 66 | + Self { frames: Vec::new() } |
| 67 | + } |
| 68 | + |
| 69 | + pub fn push(&mut self, frame: Frame) { |
| 70 | + self.frames.push(frame); |
| 71 | + } |
| 72 | + |
| 73 | + pub fn is_empty(&self) -> bool { |
| 74 | + self.frames.is_empty() |
| 75 | + } |
| 76 | +} |
| 77 | + |
| 78 | +async fn handle_hello( |
| 79 | + State(logger): State<Arc<ReplicationLogger>>, |
| 80 | +) -> std::result::Result<Response<Body>, AppError> { |
| 81 | + let hello = Hello { |
| 82 | + generation_id: logger.generation.id, |
| 83 | + generation_start_index: logger.generation.start_index, |
| 84 | + database_id: logger.database_id()?, |
| 85 | + }; |
| 86 | + |
| 87 | + let resp = Response::builder() |
| 88 | + .status(hyper::StatusCode::OK) |
| 89 | + .body(Body::from(serde_json::to_vec(&hello)?)) |
| 90 | + .unwrap(); |
| 91 | + Ok(resp) |
| 92 | +} |
| 93 | + |
| 94 | +fn error(msg: &str, code: hyper::StatusCode) -> Response<Body> { |
| 95 | + let err = serde_json::json!({ "error": msg }); |
| 96 | + Response::builder() |
| 97 | + .status(code) |
| 98 | + .body(Body::from(serde_json::to_vec(&err).unwrap())) |
| 99 | + .unwrap() |
| 100 | +} |
| 101 | + |
| 102 | +async fn handle_frames( |
| 103 | + State(logger): State<Arc<ReplicationLogger>>, |
| 104 | + req: String, // it's a JSON, but Axum errors-out if Content-Type isn't set to json, which is too strict |
| 105 | +) -> std::result::Result<Response<Body>, AppError> { |
| 106 | + const MAX_FRAMES_IN_SINGLE_RESPONSE: usize = 256; |
| 107 | + |
| 108 | + let FramesRequest { next_offset } = match serde_json::from_str(&req) { |
| 109 | + Ok(req) => req, |
| 110 | + Err(resp) => return Ok(error(&resp.to_string(), hyper::StatusCode::BAD_REQUEST)), |
| 111 | + }; |
| 112 | + tracing::trace!("Requested next offset: {next_offset}"); |
| 113 | + |
| 114 | + let next_offset = std::cmp::max(next_offset, 1); // Frames start from 1 |
| 115 | + let current_frameno = next_offset - 1; |
| 116 | + let mut frame_stream = FrameStream::new(logger.clone(), current_frameno); |
| 117 | + tracing::trace!( |
| 118 | + "Max available frame_no: {}", |
| 119 | + frame_stream.max_available_frame_no |
| 120 | + ); |
| 121 | + if frame_stream.max_available_frame_no < next_offset { |
| 122 | + tracing::trace!("No frames available starting {next_offset}, returning 204 No Content"); |
| 123 | + return Ok(Response::builder() |
| 124 | + .status(hyper::StatusCode::NO_CONTENT) |
| 125 | + .body(Body::empty())?); |
| 126 | + } |
| 127 | + |
| 128 | + let mut frames = Frames::new(); |
| 129 | + for _ in 0..MAX_FRAMES_IN_SINGLE_RESPONSE { |
| 130 | + use futures::StreamExt; |
| 131 | + |
| 132 | + match frame_stream.next().await { |
| 133 | + Some(Ok(frame)) => { |
| 134 | + tracing::trace!("Read frame {}", frame_stream.current_frame_no); |
| 135 | + frames.push(frame); |
| 136 | + } |
| 137 | + Some(Err(LogReadError::SnapshotRequired)) => { |
| 138 | + drop(frame_stream); |
| 139 | + if frames.is_empty() { |
| 140 | + tracing::debug!("Snapshot required, switching to snapshot mode"); |
| 141 | + frames = load_snapshot(logger, next_offset)?; |
| 142 | + } else { |
| 143 | + tracing::debug!("Snapshot required, but some frames were read - returning."); |
| 144 | + } |
| 145 | + break; |
| 146 | + } |
| 147 | + Some(Err(e)) => { |
| 148 | + tracing::error!("Error reading frame: {}", e); |
| 149 | + return Ok(Response::builder() |
| 150 | + .status(hyper::StatusCode::INTERNAL_SERVER_ERROR) |
| 151 | + .body(Body::empty()) |
| 152 | + .unwrap()); |
| 153 | + } |
| 154 | + None => break, |
| 155 | + } |
| 156 | + |
| 157 | + if frame_stream.max_available_frame_no <= frame_stream.current_frame_no { |
| 158 | + break; |
| 159 | + } |
| 160 | + } |
| 161 | + |
| 162 | + if frames.is_empty() { |
| 163 | + return Ok(Response::builder() |
| 164 | + .status(hyper::StatusCode::NO_CONTENT) |
| 165 | + .body(Body::empty()) |
| 166 | + .unwrap()); |
| 167 | + } |
| 168 | + |
| 169 | + Ok(Response::builder() |
| 170 | + .status(hyper::StatusCode::OK) |
| 171 | + .body(Body::from(serde_json::to_string(&frames)?)) |
| 172 | + .unwrap()) |
| 173 | +} |
| 174 | + |
| 175 | +// FIXME: In the HTTP stateless spirit, we just unconditionally send the whole snapshot |
| 176 | +// here, which is an obvious overcommit. We should instead stream in smaller parts |
| 177 | +// if the snapshot is large. |
| 178 | +fn load_snapshot(logger: Arc<ReplicationLogger>, from: u64) -> Result<Frames> { |
| 179 | + let snapshot = match logger.get_snapshot_file(from) { |
| 180 | + Ok(Some(snapshot)) => snapshot, |
| 181 | + _ => { |
| 182 | + tracing::trace!("No snapshot available, returning no frames"); |
| 183 | + return Ok(Frames { frames: Vec::new() }); |
| 184 | + } |
| 185 | + }; |
| 186 | + let mut frames = Frames::new(); |
| 187 | + for bytes in snapshot.frames_iter_from(from) { |
| 188 | + frames.push(Frame::try_from_bytes(bytes?)?); |
| 189 | + } |
| 190 | + tracing::trace!( |
| 191 | + "Loaded {} frames from the snapshot file", |
| 192 | + frames.frames.len() |
| 193 | + ); |
| 194 | + Ok(frames) |
| 195 | +} |
0 commit comments