diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 8811607f1d6..1d119ba8552 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -844,8 +844,8 @@ jobs: # the same name (we only want to document those anyway) cargo doc --no-deps --lib -p mithril-stm -p mithril-common \ -p mithril-cardano-node-chain -p mithril-cardano-node-internal-database \ - -p mithril-dmq \ - -p mithril-build-script -p mithril-cli-helper -p mithril-doc -p mithril-doc-derive \ + -p mithril-aggregator-client -p mithril-build-script -p mithril-cli-helper \ + -p mithril-dmq -p mithril-doc -p mithril-doc-derive \ -p mithril-era -p mithril-metric -p mithril-persistence -p mithril-resource-pool \ -p mithril-ticker -p mithril-signed-entity-lock -p mithril-signed-entity-preloader \ -p mithril-aggregator -p mithril-signer -p mithril-client -p mithril-client-cli \ diff --git a/Cargo.lock b/Cargo.lock index 7481cb5c561..df1d6fb9bc0 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3928,7 +3928,7 @@ dependencies = [ [[package]] name = "mithril-aggregator" -version = "0.7.77" +version = "0.7.78" dependencies = [ "anyhow", "async-trait", @@ -3985,6 +3985,27 @@ dependencies = [ "zstd", ] +[[package]] +name = "mithril-aggregator-client" +version = "0.1.0" +dependencies = [ + "anyhow", + "async-trait", + "http 1.3.1", + "httpmock", + "mithril-common", + "mockall", + "reqwest", + "semver", + "serde", + "serde_json", + "slog", + "slog-async", + "slog-term", + "thiserror 2.0.12", + "tokio", +] + [[package]] name = "mithril-aggregator-fake" version = "0.4.12" @@ -4180,7 +4201,7 @@ dependencies = [ [[package]] name = "mithril-common" -version = "0.6.12" +version = "0.6.13" dependencies = [ "anyhow", "async-trait", @@ -4390,7 +4411,7 @@ dependencies = [ [[package]] name = "mithril-signer" -version = "0.2.262" +version = "0.2.263" dependencies = [ "anyhow", "async-trait", diff --git a/Cargo.toml b/Cargo.toml index 0000bfad631..59c66526467 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -11,6 +11,7 @@ members = [ "examples/client-mithril-stake-distribution", "internal/cardano-node/mithril-cardano-node-chain", "internal/cardano-node/mithril-cardano-node-internal-database", + "internal/mithril-aggregator-client", "internal/mithril-build-script", "internal/mithril-cli-helper", "internal/mithril-dmq", diff --git a/README.md b/README.md index c03f744d740..920bd1a3ec2 100644 --- a/README.md +++ b/README.md @@ -80,11 +80,13 @@ This repository consists of the following parts: - [**Mithril signer**](./mithril-signer): the node of the **Mithril network** responsible for producing individual signatures that are collected and aggregated by the **Mithril aggregator**. - [**Internal**](./internal): the shared tools and API used by **Mithril** crates. + - [**Mithril aggregator client**](./internal/mithril-aggregator-client): a client to request data from a Mithril Aggregator, used by **Mithril network** nodes and client library. + - [**Mithril build script**](./internal/mithril-build-script): a toolbox for Mithril crates that uses a build script phase. - [**Mithril cardano-node-chain**](./internal/cardano-node/mithril-cardano-node-chain): mechanisms to read and interact with the **Cardano chain** through a Cardano node, used by **Mithril network** nodes. - - [**Mithril cardano-node-internal-database**](./internal/cardano-node/mithril-cardano-node-internal-database): mechanisms to read the files of a **Cardano node** internal database and compute digests from them, used by **Mithril network** nodes. + - [**Mithril cardano-node-internal-database**](./internal/cardano-node/mithril-cardano-node-internal-database): mechanisms to read the files of a **Cardano node** internal database and compute digests from them, used by **Mithril network** nodes and client library. - [**Mithril cli helper**](./internal/mithril-cli-helper): **CLI** tools for **Mithril** binaries. diff --git a/internal/mithril-aggregator-client/Cargo.toml b/internal/mithril-aggregator-client/Cargo.toml new file mode 100644 index 00000000000..437c9ca7e6c --- /dev/null +++ b/internal/mithril-aggregator-client/Cargo.toml @@ -0,0 +1,35 @@ +[package] +name = "mithril-aggregator-client" +version = "0.1.0" +description = "Client to request data from a Mithril Aggregator" +authors.workspace = true +documentation.workspace = true +edition.workspace = true +homepage.workspace = true +license.workspace = true +repository.workspace = true +include = ["**/*.rs", "Cargo.toml", "README.md"] + +[lib] +crate-type = ["lib", "cdylib", "staticlib"] + +[dependencies] +anyhow = { workspace = true } +async-trait = { workspace = true } +mithril-common = { path = "../../mithril-common", version = ">=0.5" } +reqwest = { workspace = true } +semver = { workspace = true } +serde = { workspace = true } +serde_json = { workspace = true } +slog = { workspace = true } +thiserror = { workspace = true } +tokio = { workspace = true } + +[dev-dependencies] +http = "1.3.1" +httpmock = "0.7.0" +mithril-common = { path = "../../mithril-common", version = ">=0.5", features = ["test_tools"] } +mockall = { workspace = true } +slog-async = { workspace = true } +slog-term = { workspace = true } +tokio = { workspace = true, features = ["macros", "rt-multi-thread"] } diff --git a/internal/mithril-aggregator-client/Makefile b/internal/mithril-aggregator-client/Makefile new file mode 100644 index 00000000000..d66d6d9fefc --- /dev/null +++ b/internal/mithril-aggregator-client/Makefile @@ -0,0 +1,19 @@ +.PHONY: all build test check doc + +CARGO = cargo + +all: test build + +build: + ${CARGO} build --release + +test: + ${CARGO} test + +check: + ${CARGO} check --release --all-features --all-targets + ${CARGO} clippy --release --all-features --all-targets + ${CARGO} fmt --check + +doc: + ${CARGO} doc --no-deps --open diff --git a/internal/mithril-aggregator-client/README.md b/internal/mithril-aggregator-client/README.md new file mode 100644 index 00000000000..f9a673353ec --- /dev/null +++ b/internal/mithril-aggregator-client/README.md @@ -0,0 +1,3 @@ +# Mithril-aggregator-client [![CI workflow](https://github.com/input-output-hk/mithril/actions/workflows/ci.yml/badge.svg)](https://github.com/input-output-hk/mithril/actions/workflows/ci.yml) [![License](https://img.shields.io/badge/license-Apache%202.0-blue?style=flat-square)](https://github.com/input-output-hk/mithril/blob/main/LICENSE) [![Discord](https://img.shields.io/discord/500028886025895936.svg?logo=discord&style=flat-square)](https://discord.gg/5kaErDKDRq) + +This crate provides a client to request data from a Mithril Aggregator. diff --git a/internal/mithril-aggregator-client/src/builder.rs b/internal/mithril-aggregator-client/src/builder.rs new file mode 100644 index 00000000000..fe464a73d69 --- /dev/null +++ b/internal/mithril-aggregator-client/src/builder.rs @@ -0,0 +1,129 @@ +use anyhow::Context; +use reqwest::{Client, IntoUrl, Proxy, Url}; +use slog::{Logger, o}; +use std::collections::HashMap; +use std::time::Duration; + +use mithril_common::StdResult; +use mithril_common::api_version::APIVersionProvider; + +use crate::client::AggregatorClient; + +/// A builder of [AggregatorClient] +pub struct AggregatorClientBuilder { + aggregator_url_result: reqwest::Result, + api_version_provider: Option, + additional_headers: Option>, + timeout_duration: Option, + relay_endpoint: Option, + logger: Option, +} + +impl AggregatorClientBuilder { + /// Constructs a new `AggregatorClientBuilder`. + // + // This is the same as `AggregatorClient::builder()`. + pub fn new(aggregator_url: U) -> Self { + Self { + aggregator_url_result: aggregator_url.into_url(), + api_version_provider: None, + additional_headers: None, + timeout_duration: None, + relay_endpoint: None, + logger: None, + } + } + + /// Set the [Logger] to use. + pub fn with_logger(mut self, logger: Logger) -> Self { + self.logger = Some(logger); + self + } + + /// Set the [APIVersionProvider] to use. + pub fn with_api_version_provider(mut self, api_version_provider: APIVersionProvider) -> Self { + self.api_version_provider = Some(api_version_provider); + self + } + + /// Set a timeout to enforce on each request + pub fn with_timeout(mut self, timeout: Duration) -> Self { + self.timeout_duration = Some(timeout); + self + } + + /// Add a set of http headers that will be sent on client requests + pub fn with_headers(mut self, custom_headers: HashMap) -> Self { + self.additional_headers = Some(custom_headers); + self + } + + /// Set the address of the relay + pub fn with_relay_endpoint(mut self, relay_endpoint: String) -> Self { + self.relay_endpoint = Some(relay_endpoint); + self + } + + /// Returns an [AggregatorClient] based on the builder configuration + pub fn build(self) -> StdResult { + let aggregator_endpoint = + enforce_trailing_slash(self.aggregator_url_result.with_context( + || "Invalid aggregator endpoint, it must be a correctly formed url", + )?); + let logger = self.logger.unwrap_or_else(|| Logger::root(slog::Discard, o!())); + let api_version_provider = self.api_version_provider.unwrap_or_default(); + let additional_headers = self.additional_headers.unwrap_or_default(); + let mut client_builder = Client::builder(); + + if let Some(relay_endpoint) = self.relay_endpoint { + client_builder = client_builder + .proxy(Proxy::all(relay_endpoint).with_context(|| "Relay proxy creation failed")?) + } + + Ok(AggregatorClient { + aggregator_endpoint, + api_version_provider, + additional_headers: (&additional_headers) + .try_into() + .with_context(|| format!("Invalid headers: '{additional_headers:?}'"))?, + timeout_duration: self.timeout_duration, + client: client_builder + .build() + .with_context(|| "HTTP client creation failed")?, + logger, + }) + } +} + +fn enforce_trailing_slash(url: Url) -> Url { + // Trailing slash is significant because url::join + // (https://docs.rs/url/latest/url/struct.Url.html#method.join) will remove + // the 'path' part of the url if it doesn't end with a trailing slash. + if url.as_str().ends_with('/') { + url + } else { + let mut url = url.clone(); + url.set_path(&format!("{}/", url.path())); + url + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn enforce_trailing_slash_for_aggregator_url() { + let url_without_trailing_slash = Url::parse("http://localhost:8080").unwrap(); + let url_with_trailing_slash = Url::parse("http://localhost:8080/").unwrap(); + + assert_eq!( + url_with_trailing_slash, + enforce_trailing_slash(url_without_trailing_slash.clone()) + ); + assert_eq!( + url_with_trailing_slash, + enforce_trailing_slash(url_with_trailing_slash.clone()) + ); + } +} diff --git a/internal/mithril-aggregator-client/src/client.rs b/internal/mithril-aggregator-client/src/client.rs new file mode 100644 index 00000000000..04812231ccb --- /dev/null +++ b/internal/mithril-aggregator-client/src/client.rs @@ -0,0 +1,477 @@ +use anyhow::{Context, anyhow}; +use reqwest::{IntoUrl, Response, Url, header::HeaderMap}; +use semver::Version; +use slog::{Logger, error, warn}; +use std::time::Duration; + +use mithril_common::MITHRIL_API_VERSION_HEADER; +use mithril_common::api_version::APIVersionProvider; + +use crate::AggregatorClientResult; +use crate::builder::AggregatorClientBuilder; +use crate::error::AggregatorClientError; +use crate::query::{AggregatorQuery, QueryContext, QueryMethod}; + +const API_VERSION_MISMATCH_WARNING_MESSAGE: &str = "OpenAPI version may be incompatible, please update Mithril client library to the latest version."; + +/// A client to send HTTP requests to a Mithril Aggregator +pub struct AggregatorClient { + pub(super) aggregator_endpoint: Url, + pub(super) api_version_provider: APIVersionProvider, + pub(super) additional_headers: HeaderMap, + pub(super) timeout_duration: Option, + pub(super) client: reqwest::Client, + pub(super) logger: Logger, +} + +impl AggregatorClient { + /// Creates a [AggregatorClientBuilder] to configure a `AggregatorClient`. + // + // This is the same as `AggregatorClient::builder()`. + pub fn builder(aggregator_url: U) -> AggregatorClientBuilder { + AggregatorClientBuilder::new(aggregator_url) + } + + /// Send the given query to the Mithril Aggregator + pub async fn send(&self, query: Q) -> AggregatorClientResult { + // Todo: error handling ? Reuse the version in `warn_if_api_version_mismatch` ? + let current_api_version = self.api_version_provider.compute_current_version().unwrap(); + let mut request_builder = match Q::method() { + QueryMethod::Get => self.client.get(self.join_aggregator_endpoint(&query.route())?), + QueryMethod::Post => self.client.post(self.join_aggregator_endpoint(&query.route())?), + } + .headers(self.additional_headers.clone()) + .header(MITHRIL_API_VERSION_HEADER, current_api_version.to_string()); + + if let Some(body) = query.body() { + request_builder = request_builder.json(&body); + } + + if let Some(timeout) = self.timeout_duration { + request_builder = request_builder.timeout(timeout); + } + + match request_builder.send().await { + Ok(response) => { + self.warn_if_api_version_mismatch(&response); + + let context = QueryContext { + response, + logger: self.logger.clone(), + }; + query.handle_response(context).await + } + Err(err) => Err(AggregatorClientError::RemoteServerUnreachable(anyhow!(err))), + } + } + + fn join_aggregator_endpoint(&self, endpoint: &str) -> AggregatorClientResult { + self.aggregator_endpoint + .join(endpoint) + .with_context(|| { + format!( + "Invalid url when joining given endpoint, '{endpoint}', to aggregator url '{}'", + self.aggregator_endpoint + ) + }) + .map_err(AggregatorClientError::InvalidEndpoint) + } + + /// Check API version mismatch and log a warning if the aggregator's version is more recent. + fn warn_if_api_version_mismatch(&self, response: &Response) { + let remote_aggregator_version = response + .headers() + .get(MITHRIL_API_VERSION_HEADER) + .and_then(|v| v.to_str().ok()) + .and_then(|s| Version::parse(s).ok()); + + let client_version = self.api_version_provider.compute_current_version(); + + match (remote_aggregator_version, client_version) { + (Some(aggregator), Ok(client)) if client < aggregator => { + warn!(self.logger, "{}", API_VERSION_MISMATCH_WARNING_MESSAGE; + "remote_aggregator_version" => %aggregator, + "caller_version" => %client, + ); + } + (Some(_), Err(error)) => { + error!( + self.logger, + "Failed to compute the current API version"; + "error" => error.to_string() + ); + } + _ => {} + } + } +} + +#[cfg(test)] +mod tests { + use http::StatusCode; + + use mithril_common::test::api_version_extensions::ApiVersionProviderTestExtension; + + use crate::test::{TestLogger, setup_server_and_client}; + + use super::*; + + #[derive(Debug, Eq, PartialEq, serde::Deserialize)] + struct TestResponse { + foo: String, + bar: i32, + } + + struct TestGetQuery; + + #[async_trait::async_trait] + impl AggregatorQuery for TestGetQuery { + type Response = TestResponse; + type Body = (); + + fn method() -> QueryMethod { + QueryMethod::Get + } + + fn route(&self) -> String { + "/dummy-get-route".to_string() + } + + async fn handle_response( + &self, + context: QueryContext, + ) -> AggregatorClientResult { + match context.response.status() { + StatusCode::OK => context + .response + .json::() + .await + .map_err(|err| AggregatorClientError::JsonParseFailed(anyhow!(err))), + _ => Err(context.unhandled_status_code().await), + } + } + } + + #[derive(Debug, Clone, Eq, PartialEq, serde::Serialize)] + struct TestBody { + pika: String, + chu: u8, + } + + impl TestBody { + fn new>(pika: P, chu: u8) -> Self { + Self { + pika: pika.into(), + chu, + } + } + } + + struct TestPostQuery { + body: TestBody, + } + + #[async_trait::async_trait] + impl AggregatorQuery for TestPostQuery { + type Response = (); + type Body = TestBody; + + fn method() -> QueryMethod { + QueryMethod::Post + } + + fn route(&self) -> String { + "/dummy-post-route".to_string() + } + + fn body(&self) -> Option { + Some(self.body.clone()) + } + + async fn handle_response( + &self, + context: QueryContext, + ) -> AggregatorClientResult { + match context.response.status() { + StatusCode::CREATED => Ok(()), + _ => Err(context.unhandled_status_code().await), + } + } + } + + #[tokio::test] + async fn test_minimal_get_query() { + let (server, client) = setup_server_and_client(); + server.mock(|when, then| { + when.method(httpmock::Method::GET).path("/dummy-get-route"); + then.status(200).body(r#"{"foo": "bar", "bar": 123}"#); + }); + + let response = client.send(TestGetQuery).await.unwrap(); + + assert_eq!( + response, + TestResponse { + foo: "bar".to_string(), + bar: 123, + } + ) + } + + #[tokio::test] + async fn test_minimal_post_query() { + let (server, client) = setup_server_and_client(); + server.mock(|when, then| { + when.method(httpmock::Method::POST) + .path("/dummy-post-route") + .header("content-type", "application/json") + .body(serde_json::to_string(&TestBody::new("miaouss", 5)).unwrap()); + then.status(201); + }); + + client + .send(TestPostQuery { + body: TestBody::new("miaouss", 5), + }) + .await + .unwrap(); + } + + #[tokio::test] + async fn test_query_send_mithril_api_version_header() { + let (server, mut client) = setup_server_and_client(); + client.api_version_provider = + APIVersionProvider::new_with_default_version(Version::parse("1.2.9").unwrap()); + server.mock(|when, then| { + when.method(httpmock::Method::GET) + .header(MITHRIL_API_VERSION_HEADER, "1.2.9"); + then.status(200).body(r#"{"foo": "a", "bar": 1}"#); + }); + + client.send(TestGetQuery).await.expect("should not fail"); + } + + #[tokio::test] + async fn test_query_send_additional_header_and_dont_override_mithril_api_version_header() { + let (server, mut client) = setup_server_and_client(); + client.api_version_provider = + APIVersionProvider::new_with_default_version(Version::parse("1.2.9").unwrap()); + client.additional_headers = { + let mut headers = HeaderMap::new(); + headers.insert(MITHRIL_API_VERSION_HEADER, "9.4.5".parse().unwrap()); + headers.insert("foo", "bar".parse().unwrap()); + headers + }; + + server.mock(|when, then| { + when.method(httpmock::Method::POST) + .header(MITHRIL_API_VERSION_HEADER, "1.2.9") + .header("foo", "bar"); + then.status(201).body(r#"{"foo": "a", "bar": 1}"#); + }); + + client + .send(TestPostQuery { + body: TestBody::new("miaouss", 3), + }) + .await + .expect("should not fail"); + } + + #[tokio::test] + async fn test_query_timeout() { + let (server, mut client) = setup_server_and_client(); + client.timeout_duration = Some(Duration::from_millis(10)); + let _server_mock = server.mock(|when, then| { + when.method(httpmock::Method::GET); + then.delay(Duration::from_millis(100)); + }); + + let error = client.send(TestGetQuery).await.expect_err("should not fail"); + + assert!( + matches!(error, AggregatorClientError::RemoteServerUnreachable(_)), + "unexpected error type: {error:?}" + ); + } + + mod warn_if_api_version_mismatch { + use http::response::Builder as HttpResponseBuilder; + use reqwest::Response; + + use mithril_common::test::logging::MemoryDrainForTestInspector; + + use super::*; + + fn build_fake_response_with_header, V: Into>( + key: K, + value: V, + ) -> Response { + HttpResponseBuilder::new() + .header(key.into(), value.into()) + .body("whatever") + .unwrap() + .into() + } + + fn assert_api_version_warning_logged, S: Into>( + log_inspector: &MemoryDrainForTestInspector, + aggregator_version: A, + client_version: S, + ) { + assert!(log_inspector.contains_log(API_VERSION_MISMATCH_WARNING_MESSAGE)); + assert!(log_inspector.contains_log(&format!( + "remote_aggregator_version={}", + aggregator_version.into() + ))); + assert!( + log_inspector.contains_log(&format!("caller_version={}", client_version.into())) + ); + } + + #[test] + fn test_logs_warning_when_aggregator_api_version_is_newer() { + let aggregator_version = "2.0.0"; + let client_version = "1.0.0"; + let (logger, log_inspector) = TestLogger::memory(); + let client = AggregatorClient::builder("http://whatever") + .with_logger(logger) + .with_api_version_provider(APIVersionProvider::new_with_default_version( + Version::parse(client_version).unwrap(), + )) + .build() + .unwrap(); + let response = + build_fake_response_with_header(MITHRIL_API_VERSION_HEADER, aggregator_version); + + assert!( + Version::parse(aggregator_version).unwrap() + > Version::parse(client_version).unwrap() + ); + + client.warn_if_api_version_mismatch(&response); + + assert_api_version_warning_logged(&log_inspector, aggregator_version, client_version); + } + + #[test] + fn test_no_warning_logged_when_versions_match() { + let version = "1.0.0"; + let (logger, log_inspector) = TestLogger::memory(); + let client = AggregatorClient::builder("http://whatever") + .with_logger(logger) + .with_api_version_provider(APIVersionProvider::new_with_default_version( + Version::parse(version).unwrap(), + )) + .build() + .unwrap(); + let response = build_fake_response_with_header(MITHRIL_API_VERSION_HEADER, version); + + client.warn_if_api_version_mismatch(&response); + + assert!(!log_inspector.contains_log(API_VERSION_MISMATCH_WARNING_MESSAGE)); + } + + #[test] + fn test_no_warning_logged_when_aggregator_api_version_is_older() { + let aggregator_version = "1.0.0"; + let client_version = "2.0.0"; + let (logger, log_inspector) = TestLogger::memory(); + let client = AggregatorClient::builder("http://whatever") + .with_logger(logger) + .with_api_version_provider(APIVersionProvider::new_with_default_version( + Version::parse(client_version).unwrap(), + )) + .build() + .unwrap(); + let response = + build_fake_response_with_header(MITHRIL_API_VERSION_HEADER, aggregator_version); + + assert!( + Version::parse(aggregator_version).unwrap() + < Version::parse(client_version).unwrap() + ); + + client.warn_if_api_version_mismatch(&response); + + assert!(!log_inspector.contains_log(API_VERSION_MISMATCH_WARNING_MESSAGE)); + } + + #[test] + fn test_does_not_log_or_fail_when_header_is_missing() { + let (logger, log_inspector) = TestLogger::memory(); + let client = AggregatorClient::builder("http://whatever") + .with_logger(logger) + .with_api_version_provider(APIVersionProvider::default()) + .build() + .unwrap(); + let response = + build_fake_response_with_header("NotMithrilAPIVersionHeader", "whatever"); + + client.warn_if_api_version_mismatch(&response); + + assert!(!log_inspector.contains_log(API_VERSION_MISMATCH_WARNING_MESSAGE)); + } + + #[test] + fn test_does_not_log_or_fail_when_header_is_not_a_version() { + let (logger, log_inspector) = TestLogger::memory(); + let client = AggregatorClient::builder("http://whatever") + .with_logger(logger) + .with_api_version_provider(APIVersionProvider::default()) + .build() + .unwrap(); + let response = + build_fake_response_with_header(MITHRIL_API_VERSION_HEADER, "not_a_version"); + + client.warn_if_api_version_mismatch(&response); + + assert!(!log_inspector.contains_log(API_VERSION_MISMATCH_WARNING_MESSAGE)); + } + + #[test] + fn test_logs_error_when_client_version_cannot_be_computed() { + let (logger, log_inspector) = TestLogger::memory(); + let client = AggregatorClient::builder("http://whatever") + .with_logger(logger) + .with_api_version_provider(APIVersionProvider::new_failing()) + .build() + .unwrap(); + let response = build_fake_response_with_header(MITHRIL_API_VERSION_HEADER, "1.0.0"); + + client.warn_if_api_version_mismatch(&response); + + assert!(!log_inspector.contains_log(API_VERSION_MISMATCH_WARNING_MESSAGE)); + } + + #[tokio::test] + async fn test_client_log_warning_if_api_version_mismatch() { + let aggregator_version = "2.0.0"; + let client_version = "1.0.0"; + let (server, mut client) = setup_server_and_client(); + let (logger, log_inspector) = TestLogger::memory(); + client.api_version_provider = APIVersionProvider::new_with_default_version( + Version::parse(client_version).unwrap(), + ); + client.logger = logger; + server.mock(|_, then| { + then.status(StatusCode::CREATED.as_u16()) + .header(MITHRIL_API_VERSION_HEADER, aggregator_version); + }); + + assert!( + Version::parse(aggregator_version).unwrap() + > Version::parse(client_version).unwrap() + ); + + client + .send(TestPostQuery { + body: TestBody::new("miaouss", 3), + }) + .await + .unwrap(); + + assert_api_version_warning_logged(&log_inspector, aggregator_version, client_version); + } + } +} diff --git a/internal/mithril-aggregator-client/src/error.rs b/internal/mithril-aggregator-client/src/error.rs new file mode 100644 index 00000000000..4232336107b --- /dev/null +++ b/internal/mithril-aggregator-client/src/error.rs @@ -0,0 +1,255 @@ +use anyhow::anyhow; +use reqwest::{Response, StatusCode, header}; +use thiserror::Error; + +use mithril_common::StdError; +use mithril_common::entities::{ClientError, ServerError}; + +use crate::JSON_CONTENT_TYPE; + +/// Error structure for the Aggregator Client. +#[derive(Error, Debug)] +pub enum AggregatorClientError { + /// Error raised when querying the aggregator returned a 5XX error. + #[error("Internal error of the Aggregator")] + RemoteServerTechnical(#[source] StdError), + + /// Error raised when querying the aggregator returned a 4XX error. + #[error("Invalid request to the Aggregator")] + RemoteServerLogical(#[source] StdError), + + /// Could not reach aggregator. + #[error("Remote server unreachable")] + RemoteServerUnreachable(#[source] StdError), + + /// Unhandled status code + #[error("Unhandled status code: {0}, response text: {1}")] + UnhandledStatusCode(StatusCode, String), + + /// Could not parse response. + #[error("Json parsing failed")] + JsonParseFailed(#[source] StdError), + + /// Failed to join the query endpoint to the aggregator url + #[error("Invalid endpoint")] + InvalidEndpoint(#[source] StdError), + + /// No signer registration round opened yet + #[error("A signer registration round is not opened yet, please try again later")] + RegistrationRoundNotYetOpened(#[source] StdError), +} + +impl AggregatorClientError { + /// Create an `AggregatorClientError` from a response. + /// + /// This method is meant to be used after handling domain-specific cases leaving only + /// 4xx or 5xx status codes. + /// Otherwise, it will return an `UnhandledStatusCode` error. + pub async fn from_response(response: Response) -> Self { + let error_code = response.status(); + + if error_code.is_client_error() { + let root_cause = Self::get_root_cause(response).await; + Self::RemoteServerLogical(anyhow!(root_cause)) + } else if error_code.is_server_error() { + let root_cause = Self::get_root_cause(response).await; + match error_code.as_u16() { + 550 => Self::RegistrationRoundNotYetOpened(anyhow!(root_cause)), + _ => Self::RemoteServerTechnical(anyhow!(root_cause)), + } + } else { + let response_text = response.text().await.unwrap_or_default(); + Self::UnhandledStatusCode(error_code, response_text) + } + } + + async fn get_root_cause(response: Response) -> String { + let error_code = response.status(); + let canonical_reason = error_code.canonical_reason().unwrap_or_default().to_lowercase(); + let is_json = response + .headers() + .get(header::CONTENT_TYPE) + .is_some_and(|ct| JSON_CONTENT_TYPE == ct); + + if is_json { + let json_value: serde_json::Value = response.json().await.unwrap_or_default(); + + if let Ok(client_error) = serde_json::from_value::(json_value.clone()) { + format!( + "{}: {}: {}", + canonical_reason, client_error.label, client_error.message + ) + } else if let Ok(server_error) = + serde_json::from_value::(json_value.clone()) + { + format!("{}: {}", canonical_reason, server_error.message) + } else if json_value.is_null() { + canonical_reason.to_string() + } else { + format!("{canonical_reason}: {json_value}") + } + } else { + let response_text = response.text().await.unwrap_or_default(); + format!("{canonical_reason}: {response_text}") + } + } +} + +#[cfg(test)] +mod tests { + use http::response::Builder as HttpResponseBuilder; + use serde_json::json; + + use super::*; + + macro_rules! assert_error_text_contains { + ($error: expr, $expect_contains: expr) => { + let error = &$error; + assert!( + error.contains($expect_contains), + "Expected error message to contain '{}'\ngot '{error:?}'", + $expect_contains, + ); + }; + } + + fn build_text_response>(status_code: StatusCode, body: T) -> Response { + HttpResponseBuilder::new() + .status(status_code) + .body(body.into()) + .unwrap() + .into() + } + + fn build_json_response(status_code: StatusCode, body: &T) -> Response { + HttpResponseBuilder::new() + .status(status_code) + .header(header::CONTENT_TYPE, JSON_CONTENT_TYPE) + .body(serde_json::to_string(&body).unwrap()) + .unwrap() + .into() + } + + #[tokio::test] + async fn test_4xx_errors_are_handled_as_remote_server_logical() { + let response = build_text_response(StatusCode::BAD_REQUEST, "error text"); + let handled_error = AggregatorClientError::from_response(response).await; + + assert!( + matches!( + handled_error, + AggregatorClientError::RemoteServerLogical(..) + ), + "Expected error to be RemoteServerLogical\ngot '{handled_error:?}'", + ); + } + + #[tokio::test] + async fn test_5xx_errors_are_handled_as_remote_server_technical() { + let response = build_text_response(StatusCode::INTERNAL_SERVER_ERROR, "error text"); + let handled_error = AggregatorClientError::from_response(response).await; + + assert!( + matches!( + handled_error, + AggregatorClientError::RemoteServerTechnical(..) + ), + "Expected error to be RemoteServerLogical\ngot '{handled_error:?}'", + ); + } + + #[tokio::test] + async fn test_550_error_is_handled_as_registration_round_not_yet_opened() { + let response = build_text_response(StatusCode::from_u16(550).unwrap(), "Not yet available"); + let handled_error = AggregatorClientError::from_response(response).await; + + assert!( + matches!( + handled_error, + AggregatorClientError::RegistrationRoundNotYetOpened(..) + ), + "Expected error to be RegistrationRoundNotYetOpened\ngot '{handled_error:?}'", + ); + } + + #[tokio::test] + async fn test_non_4xx_or_5xx_errors_are_handled_as_unhandled_status_code_and_contains_response_text() + { + let response = build_text_response(StatusCode::OK, "ok text"); + let handled_error = AggregatorClientError::from_response(response).await; + + assert!( + matches!( + handled_error, + AggregatorClientError::UnhandledStatusCode(..) if format!("{handled_error:?}").contains("ok text") + ), + "Expected error to be UnhandledStatusCode with 'ok text' in error text\ngot '{handled_error:?}'", + ); + } + + #[tokio::test] + async fn test_root_cause_of_non_json_response_contains_response_plain_text() { + let error_text = "An error occurred; please try again later."; + let response = build_text_response(StatusCode::EXPECTATION_FAILED, error_text); + + assert_error_text_contains!( + AggregatorClientError::get_root_cause(response).await, + "expectation failed: An error occurred; please try again later." + ); + } + + #[tokio::test] + async fn test_root_cause_of_json_formatted_client_error_response_contains_error_label_and_message() + { + let client_error = ClientError::new("label", "message"); + let response = build_json_response(StatusCode::BAD_REQUEST, &client_error); + + assert_error_text_contains!( + AggregatorClientError::get_root_cause(response).await, + "bad request: label: message" + ); + } + + #[tokio::test] + async fn test_root_cause_of_json_formatted_server_error_response_contains_error_label_and_message() + { + let server_error = ServerError::new("message"); + let response = build_json_response(StatusCode::BAD_REQUEST, &server_error); + + assert_error_text_contains!( + AggregatorClientError::get_root_cause(response).await, + "bad request: message" + ); + } + + #[tokio::test] + async fn test_root_cause_of_unknown_formatted_json_response_contains_json_key_value_pairs() { + let response = build_json_response( + StatusCode::INTERNAL_SERVER_ERROR, + &json!({ "second": "unknown", "first": "foreign" }), + ); + + assert_error_text_contains!( + AggregatorClientError::get_root_cause(response).await, + r#"internal server error: {"first":"foreign","second":"unknown"}"# + ); + } + + #[tokio::test] + async fn test_root_cause_with_invalid_json_response_still_contains_response_status_name() { + let response = HttpResponseBuilder::new() + .status(StatusCode::BAD_REQUEST) + .header(header::CONTENT_TYPE, JSON_CONTENT_TYPE) + .body(r#"{"invalid":"unexpected dot", "key": "value".}"#) + .unwrap() + .into(); + + let root_cause = AggregatorClientError::get_root_cause(response).await; + + assert_error_text_contains!(root_cause, "bad request"); + assert!( + !root_cause.contains("bad request: "), + "Expected error message should not contain additional information \ngot '{root_cause:?}'" + ); + } +} diff --git a/internal/mithril-aggregator-client/src/lib.rs b/internal/mithril-aggregator-client/src/lib.rs new file mode 100644 index 00000000000..2549ef9f302 --- /dev/null +++ b/internal/mithril-aggregator-client/src/lib.rs @@ -0,0 +1,20 @@ +#![warn(missing_docs)] +//! This crate provides a client to request data from a Mithril Aggregator. +//! + +mod builder; +mod client; +mod error; +pub mod query; +#[cfg(test)] +mod test; + +pub use builder::AggregatorClientBuilder; +pub use client::AggregatorClient; +pub use error::AggregatorClientError; + +pub(crate) const JSON_CONTENT_TYPE: reqwest::header::HeaderValue = + reqwest::header::HeaderValue::from_static("application/json"); + +/// Aggregator-client result type +pub type AggregatorClientResult = Result; diff --git a/internal/mithril-aggregator-client/src/query/api.rs b/internal/mithril-aggregator-client/src/query/api.rs new file mode 100644 index 00000000000..e083cf54e30 --- /dev/null +++ b/internal/mithril-aggregator-client/src/query/api.rs @@ -0,0 +1,43 @@ +use reqwest::Response; +use serde::de::DeserializeOwned; +use slog::Logger; + +use crate::AggregatorClientResult; +use crate::error::AggregatorClientError; + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum QueryMethod { + Get, + Post, +} + +#[cfg_attr(target_family = "wasm", async_trait::async_trait(?Send))] +#[cfg_attr(not(target_family = "wasm"), async_trait::async_trait)] +pub trait AggregatorQuery { + type Response: DeserializeOwned; + type Body: serde::Serialize + Sized; + + fn method() -> QueryMethod; + + fn route(&self) -> String; + + fn body(&self) -> Option { + None + } + + async fn handle_response( + &self, + context: QueryContext, + ) -> AggregatorClientResult; +} + +pub struct QueryContext { + pub(crate) response: Response, + pub(crate) logger: Logger, +} + +impl QueryContext { + pub async fn unhandled_status_code(self) -> AggregatorClientError { + AggregatorClientError::from_response(self.response).await + } +} diff --git a/internal/mithril-aggregator-client/src/query/certificate/get_certificate_details.rs b/internal/mithril-aggregator-client/src/query/certificate/get_certificate_details.rs new file mode 100644 index 00000000000..6ae31f44448 --- /dev/null +++ b/internal/mithril-aggregator-client/src/query/certificate/get_certificate_details.rs @@ -0,0 +1,168 @@ +use anyhow::anyhow; +use async_trait::async_trait; +use reqwest::StatusCode; +use slog::debug; + +use mithril_common::messages::CertificateMessage; + +use crate::AggregatorClientResult; +use crate::error::AggregatorClientError; +use crate::query::{AggregatorQuery, QueryContext, QueryMethod}; + +/// Get the details of a certificate +pub struct CertificateDetailsQuery { + hash: String, +} + +impl CertificateDetailsQuery { + /// Instantiate a query to get a certificate by hash + pub fn by_hash>(hash: H) -> Self { + Self { hash: hash.into() } + } + + /// Instantiate a query to get the latest genesis certificate + pub fn latest_genesis() -> Self { + Self { + hash: "genesis".to_string(), + } + } +} + +#[cfg_attr(target_family = "wasm", async_trait(?Send))] +#[cfg_attr(not(target_family = "wasm"), async_trait)] +impl AggregatorQuery for CertificateDetailsQuery { + type Response = Option; + type Body = (); + + fn method() -> QueryMethod { + QueryMethod::Get + } + + fn route(&self) -> String { + format!("certificate/{}", self.hash) + } + + async fn handle_response( + &self, + context: QueryContext, + ) -> AggregatorClientResult { + debug!(context.logger, "Retrieve certificate details"; "certificate_hash" => %self.hash); + + match context.response.status() { + StatusCode::OK => match context.response.json::().await { + Ok(message) => Ok(Some(message)), + Err(err) => Err(AggregatorClientError::JsonParseFailed(anyhow!(err))), + }, + StatusCode::NOT_FOUND => Ok(None), + _ => Err(context.unhandled_status_code().await), + } + } +} + +#[cfg(test)] +mod tests { + use serde_json::json; + + use mithril_common::test::double::Dummy; + + use crate::test::setup_server_and_client; + + use super::*; + + #[tokio::test] + async fn test_certificates_details_ok_200() { + let (server, client) = setup_server_and_client(); + let expected_message = CertificateMessage::dummy(); + let _server_mock = server.mock(|when, then| { + when.path(format!("/certificate/{}", expected_message.hash)); + then.status(200).body(json!(expected_message).to_string()); + }); + + let fetched_message = client + .send(CertificateDetailsQuery::by_hash(&expected_message.hash)) + .await + .unwrap(); + + assert_eq!(Some(expected_message), fetched_message); + } + + #[tokio::test] + async fn test_certificates_details_ok_404() { + let (server, client) = setup_server_and_client(); + let _server_mock = server.mock(|when, then| { + when.any_request(); + then.status(404); + }); + + let fetched_message = client + .send(CertificateDetailsQuery::by_hash("whatever")) + .await + .unwrap(); + + assert_eq!(None, fetched_message); + } + + #[tokio::test] + async fn test_certificates_details_ko_500() { + let (server, client) = setup_server_and_client(); + let _server_mock = server.mock(|when, then| { + when.any_request(); + then.status(500).body("an error occurred"); + }); + + match client + .send(CertificateDetailsQuery::by_hash("whatever")) + .await + .unwrap_err() + { + AggregatorClientError::RemoteServerTechnical(_) => (), + e => panic!("Expected Aggregator::RemoteServerTechnical error, got '{e:?}'."), + }; + } + + #[tokio::test] + async fn test_latest_genesis_ok_200() { + let (server, client) = setup_server_and_client(); + let genesis_message = CertificateMessage::dummy(); + let _server_mock = server.mock(|when, then| { + when.path("/certificate/genesis"); + then.status(200).body(json!(genesis_message).to_string()); + }); + + let fetched = client.send(CertificateDetailsQuery::latest_genesis()).await.unwrap(); + + assert_eq!(Some(genesis_message), fetched); + } + + #[tokio::test] + async fn test_latest_genesis_ok_404() { + let (server, client) = setup_server_and_client(); + let _server_mock = server.mock(|when, then| { + when.path("/certificate/genesis"); + then.status(404); + }); + + let fetched = client.send(CertificateDetailsQuery::latest_genesis()).await.unwrap(); + + assert_eq!(None, fetched); + } + + #[tokio::test] + async fn test_latest_genesis_ko_500() { + let (server, client) = setup_server_and_client(); + let _server_mock = server.mock(|when, then| { + when.path("/certificate/genesis"); + then.status(500).body("an error occurred"); + }); + + let error = client + .send(CertificateDetailsQuery::latest_genesis()) + .await + .unwrap_err(); + + assert!( + matches!(error, AggregatorClientError::RemoteServerTechnical(_)), + "Expected Aggregator::RemoteServerTechnical error, got {error:?}" + ); + } +} diff --git a/internal/mithril-aggregator-client/src/query/certificate/mod.rs b/internal/mithril-aggregator-client/src/query/certificate/mod.rs new file mode 100644 index 00000000000..c09cabaf25f --- /dev/null +++ b/internal/mithril-aggregator-client/src/query/certificate/mod.rs @@ -0,0 +1,3 @@ +mod get_certificate_details; + +pub use get_certificate_details::*; diff --git a/internal/mithril-aggregator-client/src/query/mod.rs b/internal/mithril-aggregator-client/src/query/mod.rs new file mode 100644 index 00000000000..2cff48897fd --- /dev/null +++ b/internal/mithril-aggregator-client/src/query/mod.rs @@ -0,0 +1,10 @@ +//! Provides queries to retrieve or send data to a Mithril aggregator +//! +//! Available queries +//! - Certificate: Get by hash, get latest genesis certificate +//! +mod api; +mod certificate; + +pub(crate) use api::*; +pub use certificate::*; diff --git a/internal/mithril-aggregator-client/src/test/mod.rs b/internal/mithril-aggregator-client/src/test/mod.rs new file mode 100644 index 00000000000..265f3d6d400 --- /dev/null +++ b/internal/mithril-aggregator-client/src/test/mod.rs @@ -0,0 +1,17 @@ +use httpmock::MockServer; + +use crate::AggregatorClient; + +#[cfg(test)] +mithril_common::define_test_logger!(); + +#[cfg(test)] +pub(crate) fn setup_server_and_client() -> (MockServer, AggregatorClient) { + let server = MockServer::start(); + let client = AggregatorClient::builder(server.base_url()) + .with_logger(TestLogger::stdout()) + .build() + .unwrap(); + + (server, client) +} diff --git a/mithril-aggregator/Cargo.toml b/mithril-aggregator/Cargo.toml index c00570a88d7..0d706f8e19a 100644 --- a/mithril-aggregator/Cargo.toml +++ b/mithril-aggregator/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "mithril-aggregator" -version = "0.7.77" +version = "0.7.78" description = "A Mithril Aggregator server" authors = { workspace = true } edition = { workspace = true } diff --git a/mithril-aggregator/src/services/aggregator_client.rs b/mithril-aggregator/src/services/aggregator_client.rs index 836993c8721..02d8a2cfe7e 100644 --- a/mithril-aggregator/src/services/aggregator_client.rs +++ b/mithril-aggregator/src/services/aggregator_client.rs @@ -812,6 +812,7 @@ mod tests { mod warn_if_api_version_mismatch { use std::collections::HashMap; + use mithril_common::test::api_version_extensions::ApiVersionProviderTestExtension; use mithril_common::test::logging::MemoryDrainForTestInspector; use super::*; diff --git a/mithril-common/Cargo.toml b/mithril-common/Cargo.toml index 105bcdafa63..c9cf9bece2b 100644 --- a/mithril-common/Cargo.toml +++ b/mithril-common/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "mithril-common" -version = "0.6.12" +version = "0.6.13" description = "Common types, interfaces, and utilities for Mithril nodes." authors = { workspace = true } edition = { workspace = true } diff --git a/mithril-common/src/api_version.rs b/mithril-common/src/api_version.rs index 3d069df2113..c34c35c3430 100644 --- a/mithril-common/src/api_version.rs +++ b/mithril-common/src/api_version.rs @@ -63,18 +63,47 @@ impl APIVersionProvider { versions.sort(); versions } +} + +impl Default for APIVersionProvider { + fn default() -> Self { + struct DiscriminantSourceDefault; + impl ApiVersionDiscriminantSource for DiscriminantSourceDefault { + fn get_discriminant(&self) -> String { + // Return nonexistent discriminant to ensure the default 'openapi.yml' file is used + "nonexistent-discriminant".to_string() + } + } - /// Update open api versions. Test only - pub fn update_open_api_versions( - &mut self, - open_api_versions: HashMap, - ) { + Self::new(Arc::new(DiscriminantSourceDefault)) + } +} + +#[cfg(any(test, feature = "test_tools"))] +impl crate::test::api_version_extensions::ApiVersionProviderTestExtension for APIVersionProvider { + fn update_open_api_versions(&mut self, open_api_versions: HashMap) { self.open_api_versions = open_api_versions; } + + fn new_with_default_version(version: Version) -> APIVersionProvider { + Self { + open_api_versions: HashMap::from([("openapi.yaml".to_string(), version)]), + ..Self::default() + } + } + + fn new_failing() -> APIVersionProvider { + Self { + // Leverage the error raised if the default api version is missing + open_api_versions: HashMap::new(), + ..Self::default() + } + } } #[cfg(test)] mod test { + use crate::test::api_version_extensions::ApiVersionProviderTestExtension; use crate::test::double::DummyApiVersionDiscriminantSource; use super::*; @@ -152,4 +181,29 @@ mod test { assert!(!all_versions_sorted.is_empty()); } + + #[test] + fn default_provider_returns_default_version() { + let provider = APIVersionProvider::default(); + let version = provider.compute_current_version().unwrap(); + + assert_eq!( + get_open_api_versions_mapping().get("openapi.yaml").unwrap(), + &version + ); + } + + #[test] + fn building_provider_with_canned_default_openapi_version() { + let provider = APIVersionProvider::new_with_default_version(Version::new(1, 2, 3)); + let version = provider.compute_current_version().unwrap(); + + assert_eq!(Version::new(1, 2, 3), version); + } + + #[test] + fn building_provider_that_fails_compute_current_version() { + let provider = APIVersionProvider::new_failing(); + provider.compute_current_version().expect_err("Should fail"); + } } diff --git a/mithril-common/src/test/api_version_extensions.rs b/mithril-common/src/test/api_version_extensions.rs new file mode 100644 index 00000000000..0929bfb2104 --- /dev/null +++ b/mithril-common/src/test/api_version_extensions.rs @@ -0,0 +1,18 @@ +//! A set of extension traits to add test utilities to this crate `APIVersionProvider` + +use semver::Version; +use std::collections::HashMap; + +use crate::api_version::{APIVersionProvider, OpenAPIFileName}; + +/// Extension trait adding test utilities to [APIVersionProvider] +pub trait ApiVersionProviderTestExtension { + /// `TEST ONLY` - Replace the open api versions + fn update_open_api_versions(&mut self, open_api_versions: HashMap); + + /// `TEST ONLY` - Set up an ` APIVersionProvider ` with the given version for the `openapi.yaml` file + fn new_with_default_version(version: Version) -> APIVersionProvider; + + /// `TEST ONLY` - Set up an ` APIVersionProvider ` that fails to compute api versions + fn new_failing() -> APIVersionProvider; +} diff --git a/mithril-common/src/test/mod.rs b/mithril-common/src/test/mod.rs index 922ca2f5779..49938f23803 100644 --- a/mithril-common/src/test/mod.rs +++ b/mithril-common/src/test/mod.rs @@ -12,6 +12,7 @@ //! * `temp_dir`: Temporary directory management for tests //! +pub mod api_version_extensions; pub mod builder; pub mod crypto_helper; pub mod double; diff --git a/mithril-signer/Cargo.toml b/mithril-signer/Cargo.toml index 06af998a2ee..ecc5ef080fb 100644 --- a/mithril-signer/Cargo.toml +++ b/mithril-signer/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "mithril-signer" -version = "0.2.262" +version = "0.2.263" description = "A Mithril Signer" authors = { workspace = true } edition = { workspace = true } diff --git a/mithril-signer/src/services/aggregator_client.rs b/mithril-signer/src/services/aggregator_client.rs index d5695a68ba7..1ccd35c0c65 100644 --- a/mithril-signer/src/services/aggregator_client.rs +++ b/mithril-signer/src/services/aggregator_client.rs @@ -1057,6 +1057,8 @@ mod tests { } mod warn_if_api_version_mismatch { + use mithril_common::test::api_version_extensions::ApiVersionProviderTestExtension; + use super::*; fn version_provider_with_open_api_version>(