diff --git a/apps/hermes/src/api.rs b/apps/hermes/src/api.rs index dfac367885..64c22907e5 100644 --- a/apps/hermes/src/api.rs +++ b/apps/hermes/src/api.rs @@ -1,7 +1,12 @@ use { crate::{ config::RunOptions, - state::State, + state::{ + Aggregates, + Benchmarks, + Cache, + Metrics, + }, }, anyhow::Result, axum::{ @@ -24,7 +29,7 @@ mod rest; pub mod types; mod ws; -pub struct ApiState { +pub struct ApiState { pub state: Arc, pub ws: Arc, pub metrics: Arc, @@ -42,12 +47,12 @@ impl Clone for ApiState { } } -impl ApiState { - pub fn new( - state: Arc, - ws_whitelist: Vec, - requester_ip_header_name: String, - ) -> Self { +impl ApiState { + pub fn new(state: Arc, ws_whitelist: Vec, requester_ip_header_name: String) -> Self + where + S: Metrics, + S: Send + Sync + 'static, + { Self { metrics: Arc::new(metrics_middleware::ApiMetrics::new(state.clone())), ws: Arc::new(ws::WsState::new( @@ -61,7 +66,14 @@ impl ApiState { } #[tracing::instrument(skip(opts, state))] -pub async fn spawn(opts: RunOptions, state: Arc) -> Result<()> { +pub async fn spawn(opts: RunOptions, state: Arc) -> Result<()> +where + S: Aggregates, + S: Benchmarks, + S: Cache, + S: Metrics, + S: Send + Sync + 'static, +{ let state = { let opts = opts.clone(); ApiState::new( @@ -79,7 +91,14 @@ pub async fn spawn(opts: RunOptions, state: Arc) -> Result<()> { /// Currently this is based on Axum due to the simplicity and strong ecosystem support for the /// packages they are based on (tokio & hyper). #[tracing::instrument(skip(opts, state))] -pub async fn run(opts: RunOptions, state: ApiState) -> Result<()> { +pub async fn run(opts: RunOptions, state: ApiState) -> Result<()> +where + S: Aggregates, + S: Benchmarks, + S: Cache, + S: Metrics, + S: Send + Sync + 'static, +{ tracing::info!(endpoint = %opts.rpc.listen_addr, "Starting RPC Server."); #[derive(OpenApi)] diff --git a/apps/hermes/src/api/metrics_middleware.rs b/apps/hermes/src/api/metrics_middleware.rs index ce5677e536..be97f25a52 100644 --- a/apps/hermes/src/api/metrics_middleware.rs +++ b/apps/hermes/src/api/metrics_middleware.rs @@ -31,9 +31,7 @@ impl ApiMetrics { pub fn new(state: Arc) -> Self where S: Metrics, - S: Send, - S: Sync, - S: 'static, + S: Send + Sync + 'static, { let new = Self { requests: Family::default(), @@ -81,8 +79,8 @@ pub struct Labels { pub status: u16, } -pub async fn track_metrics( - State(api_state): State, +pub async fn track_metrics( + State(api_state): State>, req: Request, next: Next, ) -> impl IntoResponse { diff --git a/apps/hermes/src/api/rest/v2/price_feeds_metadata.rs b/apps/hermes/src/api/rest/v2/price_feeds_metadata.rs index 9811cbdf3e..7ba9a2262a 100644 --- a/apps/hermes/src/api/rest/v2/price_feeds_metadata.rs +++ b/apps/hermes/src/api/rest/v2/price_feeds_metadata.rs @@ -8,7 +8,7 @@ use { }, ApiState, }, - price_feeds_metadata::PriceFeedMeta, + state::price_feeds_metadata::PriceFeedMeta, }, anyhow::Result, axum::{ diff --git a/apps/hermes/src/api/rest/v2/sse.rs b/apps/hermes/src/api/rest/v2/sse.rs index 6c2f988132..e1d55dde70 100644 --- a/apps/hermes/src/api/rest/v2/sse.rs +++ b/apps/hermes/src/api/rest/v2/sse.rs @@ -95,9 +95,7 @@ pub async fn price_stream_sse_handler( ) -> Result>>, RestError> where S: Aggregates, - S: Sync, - S: Send, - S: 'static, + S: Send + Sync + 'static, { let price_ids: Vec = params.ids.into_iter().map(Into::into).collect(); diff --git a/apps/hermes/src/api/ws.rs b/apps/hermes/src/api/ws.rs index e29ef95eb7..125be8ac1f 100644 --- a/apps/hermes/src/api/ws.rs +++ b/apps/hermes/src/api/ws.rs @@ -13,7 +13,9 @@ use { RequestTime, }, metrics::Metrics, - State, + Benchmarks, + Cache, + PriceFeedMeta, }, anyhow::{ anyhow, @@ -124,9 +126,7 @@ impl WsMetrics { pub fn new(state: Arc) -> Self where S: Metrics, - S: Send, - S: Sync, - S: 'static, + S: Send + Sync + 'static, { let new = Self { interactions: Family::default(), @@ -161,7 +161,11 @@ pub struct WsState { } impl WsState { - pub fn new(whitelist: Vec, requester_ip_header_name: String, state: Arc) -> Self { + pub fn new(whitelist: Vec, requester_ip_header_name: String, state: Arc) -> Self + where + S: Metrics, + S: Send + Sync + 'static, + { Self { subscriber_counter: AtomicUsize::new(0), rate_limiter: RateLimiter::dashmap(Quota::per_second(nonzero!( @@ -211,11 +215,18 @@ enum ServerResponseMessage { Err { error: String }, } -pub async fn ws_route_handler( +pub async fn ws_route_handler( ws: WebSocketUpgrade, - AxumState(state): AxumState, + AxumState(state): AxumState>, headers: HeaderMap, -) -> impl IntoResponse { +) -> impl IntoResponse +where + S: Aggregates, + S: Benchmarks, + S: Cache, + S: PriceFeedMeta, + S: Send + Sync + 'static, +{ let requester_ip = headers .get(state.ws.requester_ip_header_name.as_str()) .and_then(|value| value.to_str().ok()) @@ -230,6 +241,7 @@ pub async fn ws_route_handler( async fn websocket_handler(stream: WebSocket, state: ApiState, subscriber_ip: Option) where S: Aggregates, + S: Send, { let ws_state = state.ws.clone(); diff --git a/apps/hermes/src/main.rs b/apps/hermes/src/main.rs index 174af203ec..b0a58e7a81 100644 --- a/apps/hermes/src/main.rs +++ b/apps/hermes/src/main.rs @@ -9,7 +9,6 @@ use { }, futures::future::join_all, lazy_static::lazy_static, - state::State, std::io::IsTerminal, tokio::{ spawn, @@ -21,7 +20,6 @@ mod api; mod config; mod metrics_server; mod network; -mod price_feeds_metadata; mod serde; mod state; @@ -53,7 +51,7 @@ async fn init() -> Result<()> { let (update_tx, _) = tokio::sync::broadcast::channel(1000); // Initialize a cache store with a 1000 element circular buffer. - let state = State::new(update_tx.clone(), 1000, opts.benchmarks.endpoint.clone()); + let state = state::new(update_tx.clone(), 1000, opts.benchmarks.endpoint.clone()); // Listen for Ctrl+C so we can set the exit flag and wait for a graceful shutdown. spawn(async move { diff --git a/apps/hermes/src/metrics_server.rs b/apps/hermes/src/metrics_server.rs index b6482b8f54..93b7d3ff93 100644 --- a/apps/hermes/src/metrics_server.rs +++ b/apps/hermes/src/metrics_server.rs @@ -5,10 +5,7 @@ use { crate::{ config::RunOptions, - state::{ - metrics::Metrics, - State as AppState, - }, + state::metrics::Metrics, }, anyhow::Result, axum::{ @@ -23,7 +20,11 @@ use { #[tracing::instrument(skip(opts, state))] -pub async fn run(opts: RunOptions, state: Arc) -> Result<()> { +pub async fn run(opts: RunOptions, state: Arc) -> Result<()> +where + S: Metrics, + S: Send + Sync + 'static, +{ tracing::info!(endpoint = %opts.metrics.server_listen_addr, "Starting Metrics Server."); let app = Router::new(); @@ -44,7 +45,10 @@ pub async fn run(opts: RunOptions, state: Arc) -> Result<()> { Ok(()) } -pub async fn metrics(State(state): State>) -> impl IntoResponse { +pub async fn metrics(State(state): State>) -> impl IntoResponse +where + S: Metrics, +{ let buffer = Metrics::encode(&*state).await; ( [( diff --git a/apps/hermes/src/network/pythnet.rs b/apps/hermes/src/network/pythnet.rs index 929c3ed7a6..401e90fc80 100644 --- a/apps/hermes/src/network/pythnet.rs +++ b/apps/hermes/src/network/pythnet.rs @@ -11,18 +11,17 @@ use { GuardianSet, GuardianSetData, }, - price_feeds_metadata::{ - PriceFeedMeta, - DEFAULT_PRICE_FEEDS_CACHE_UPDATE_INTERVAL, - }, state::{ aggregate::{ AccumulatorMessages, Aggregates, Update, }, + price_feeds_metadata::{ + PriceFeedMeta, + DEFAULT_PRICE_FEEDS_CACHE_UPDATE_INTERVAL, + }, wormhole::Wormhole, - State, }, }, anyhow::{ @@ -139,7 +138,12 @@ async fn fetch_bridge_data( } } -pub async fn run(store: Arc, pythnet_ws_endpoint: String) -> Result { +pub async fn run(store: Arc, pythnet_ws_endpoint: String) -> Result +where + S: Aggregates, + S: Wormhole, + S: Send + Sync + 'static, +{ let client = PubsubClient::new(pythnet_ws_endpoint.as_ref()).await?; let config = RpcProgramAccountsConfig { @@ -222,6 +226,7 @@ async fn fetch_existing_guardian_sets( ) -> Result<()> where S: Wormhole, + S: Send + Sync + 'static, { let client = RpcClient::new(pythnet_http_endpoint.to_string()); let bridge = fetch_bridge_data(&client, &wormhole_contract_addr).await?; @@ -261,7 +266,11 @@ where } #[tracing::instrument(skip(opts, state))] -pub async fn spawn(opts: RunOptions, state: Arc) -> Result<()> { +pub async fn spawn(opts: RunOptions, state: Arc) -> Result<()> +where + S: Wormhole, + S: Send + Sync + 'static, +{ tracing::info!(endpoint = opts.pythnet.ws_addr, "Started Pythnet Listener."); // Create RpcClient instance here diff --git a/apps/hermes/src/network/wormhole.rs b/apps/hermes/src/network/wormhole.rs index c29b3bb375..f8e176cabf 100644 --- a/apps/hermes/src/network/wormhole.rs +++ b/apps/hermes/src/network/wormhole.rs @@ -7,10 +7,7 @@ use { crate::{ config::RunOptions, - state::{ - wormhole::Wormhole, - State, - }, + state::wormhole::Wormhole, }, anyhow::{ anyhow, @@ -118,7 +115,11 @@ mod proto { // Launches the Wormhole gRPC service. #[tracing::instrument(skip(opts, state))] -pub async fn spawn(opts: RunOptions, state: Arc) -> Result<()> { +pub async fn spawn(opts: RunOptions, state: Arc) -> Result<()> +where + S: Wormhole, + S: Send + Sync + 'static, +{ let mut exit = crate::EXIT.subscribe(); loop { let current_time = Instant::now(); @@ -142,9 +143,7 @@ pub async fn spawn(opts: RunOptions, state: Arc) -> Result<()> { async fn run(opts: RunOptions, state: Arc) -> Result where S: Wormhole, - S: Sync, - S: Send, - S: 'static, + S: Send + Sync + 'static, { let mut client = SpyRpcServiceClient::connect(opts.wormhole.spy_rpc_addr).await?; let mut stream = client diff --git a/apps/hermes/src/state.rs b/apps/hermes/src/state.rs index de9aa1ebff..59cf19435d 100644 --- a/apps/hermes/src/state.rs +++ b/apps/hermes/src/state.rs @@ -9,9 +9,9 @@ use { benchmarks::BenchmarksState, cache::CacheState, metrics::MetricsState, + price_feeds_metadata::PriceFeedMetaState, wormhole::WormholeState, }, - crate::price_feeds_metadata::PriceFeedMetaState, prometheus_client::registry::Registry, reqwest::Url, std::sync::Arc, @@ -22,9 +22,25 @@ pub mod aggregate; pub mod benchmarks; pub mod cache; pub mod metrics; +pub mod price_feeds_metadata; pub mod wormhole; -pub struct State { +// Expose State interfaces and types for other modules. +pub use { + aggregate::Aggregates, + benchmarks::Benchmarks, + cache::Cache, + metrics::Metrics, + price_feeds_metadata::PriceFeedMeta, + wormhole::Wormhole, +}; + +/// State contains all relevant shared application state. +/// +/// This type is intentionally not exposed, forcing modules to interface with the +/// various API's using the provided traits. This is done to enforce separation of +/// concerns and to avoid direct manipulation of state. +struct State { /// State for the `Cache` service for short-lived storage of updates. pub cache: CacheState, @@ -44,36 +60,40 @@ pub struct State { pub metrics: MetricsState, } -impl State { - pub fn new( - update_tx: Sender, - cache_size: u64, - benchmarks_endpoint: Option, - ) -> Arc { - let mut metrics_registry = Registry::default(); - Arc::new(Self { - cache: CacheState::new(cache_size), - benchmarks: BenchmarksState::new(benchmarks_endpoint), - price_feed_meta: PriceFeedMetaState::new(), - aggregates: AggregateState::new(update_tx, &mut metrics_registry), - wormhole: WormholeState::new(), - metrics: MetricsState::new(metrics_registry), - }) - } +pub fn new( + update_tx: Sender, + cache_size: u64, + benchmarks_endpoint: Option, +) -> Arc { + let mut metrics_registry = Registry::default(); + Arc::new(State { + cache: CacheState::new(cache_size), + benchmarks: BenchmarksState::new(benchmarks_endpoint), + price_feed_meta: PriceFeedMetaState::new(), + aggregates: AggregateState::new(update_tx, &mut metrics_registry), + wormhole: WormholeState::new(), + metrics: MetricsState::new(metrics_registry), + }) } #[cfg(test)] pub mod test { use { - self::wormhole::Wormhole, - super::*, + super::{ + aggregate::AggregationEvent, + Aggregates, + Wormhole, + }, crate::network::wormhole::GuardianSet, + std::sync::Arc, tokio::sync::broadcast::Receiver, }; - pub async fn setup_state(cache_size: u64) -> (Arc, Receiver) { + pub async fn setup_state( + cache_size: u64, + ) -> (Arc, Receiver) { let (update_tx, update_rx) = tokio::sync::broadcast::channel(1000); - let state = State::new(update_tx, cache_size, None); + let state = super::new(update_tx, cache_size, None); // Add an initial guardian set with public key 0 Wormhole::update_guardian_set( diff --git a/apps/hermes/src/state/aggregate.rs b/apps/hermes/src/state/aggregate.rs index 89569d25b0..55a2fb692a 100644 --- a/apps/hermes/src/state/aggregate.rs +++ b/apps/hermes/src/state/aggregate.rs @@ -20,7 +20,6 @@ use { }, crate::{ network::wormhole::VaaBytes, - price_feeds_metadata::PriceFeedMeta, state::{ benchmarks::Benchmarks, cache::{ @@ -28,6 +27,7 @@ use { MessageState, MessageStateFilter, }, + price_feeds_metadata::PriceFeedMeta, State, }, }, @@ -612,7 +612,11 @@ mod test { } } - pub async fn store_multiple_concurrent_valid_updates(state: Arc, updates: Vec) { + pub async fn store_multiple_concurrent_valid_updates(state: Arc, updates: Vec) + where + S: Aggregates, + S: Send + Sync + 'static, + { let res = join_all(updates.into_iter().map(|u| state.store_update(u))).await; // Check that all store_update calls succeeded assert!(res.into_iter().all(|r| r.is_ok())); diff --git a/apps/hermes/src/price_feeds_metadata.rs b/apps/hermes/src/state/price_feeds_metadata.rs similarity index 96% rename from apps/hermes/src/price_feeds_metadata.rs rename to apps/hermes/src/state/price_feeds_metadata.rs index 623144970e..312f31049f 100644 --- a/apps/hermes/src/price_feeds_metadata.rs +++ b/apps/hermes/src/state/price_feeds_metadata.rs @@ -12,16 +12,11 @@ use { pub const DEFAULT_PRICE_FEEDS_CACHE_UPDATE_INTERVAL: u64 = 600; +#[derive(Default)] pub struct PriceFeedMetaState { pub data: RwLock>, } -impl Default for PriceFeedMetaState { - fn default() -> Self { - Self::new() - } -} - impl PriceFeedMetaState { pub fn new() -> Self { Self {