diff --git a/hermes/src/aggregate.rs b/hermes/src/aggregate.rs index 673637502a..81b0b56a25 100644 --- a/hermes/src/aggregate.rs +++ b/hermes/src/aggregate.rs @@ -426,7 +426,7 @@ where pub async fn is_ready(state: &State) -> bool { let metadata = state.aggregate_state.read().await; - let price_feeds_metadata = state.price_feeds_metadata.read().await; + let price_feeds_metadata = state.price_feed_meta.data.read().await; let has_completed_recently = match metadata.latest_completed_update_at.as_ref() { Some(latest_completed_update_time) => { @@ -456,7 +456,7 @@ mod test { super::*, crate::{ api::types::PriceFeedMetadata, - price_feeds_metadata::store_price_feeds_metadata, + price_feeds_metadata::PriceFeedMeta, state::test::setup_state, }, futures::future::join_all, @@ -809,15 +809,13 @@ mod test { // Add a dummy price feeds metadata - store_price_feeds_metadata( - &state, - &[PriceFeedMetadata { + state + .store_price_feeds_metadata(&[PriceFeedMetadata { id: PriceIdentifier::new([100; 32]), attributes: Default::default(), - }], - ) - .await - .unwrap(); + }]) + .await + .unwrap(); // Check the state is ready assert!(is_ready(&state).await); diff --git a/hermes/src/api.rs b/hermes/src/api.rs index d0031e94f5..edff3f1dd8 100644 --- a/hermes/src/api.rs +++ b/hermes/src/api.rs @@ -26,15 +26,27 @@ mod rest; pub mod types; mod ws; -#[derive(Clone)] -pub struct ApiState { - pub state: Arc, +pub struct ApiState { + pub state: Arc, pub ws: Arc, pub metrics: Arc, pub update_tx: Sender, } -impl ApiState { +/// Manually implement `Clone` as the derive macro will try and slap `Clone` on +/// `State` which should not be Clone. +impl Clone for ApiState { + fn clone(&self) -> Self { + Self { + state: self.state.clone(), + ws: self.ws.clone(), + metrics: self.metrics.clone(), + update_tx: self.update_tx.clone(), + } + } +} + +impl ApiState { pub fn new( state: Arc, ws_whitelist: Vec, diff --git a/hermes/src/api/rest/v2/price_feeds_metadata.rs b/hermes/src/api/rest/v2/price_feeds_metadata.rs index ff05be6bd2..9811cbdf3e 100644 --- a/hermes/src/api/rest/v2/price_feeds_metadata.rs +++ b/hermes/src/api/rest/v2/price_feeds_metadata.rs @@ -6,8 +6,9 @@ use { AssetType, PriceFeedMetadata, }, + ApiState, }, - price_feeds_metadata::get_price_feeds_metadata, + price_feeds_metadata::PriceFeedMeta, }, anyhow::Result, axum::{ @@ -46,19 +47,23 @@ pub struct PriceFeedsMetadataQueryParams { PriceFeedsMetadataQueryParams ) )] -pub async fn price_feeds_metadata( - State(state): State, +pub async fn price_feeds_metadata( + State(state): State>, QsQuery(params): QsQuery, -) -> Result>, RestError> { - let price_feeds_metadata = - get_price_feeds_metadata(&state.state, params.query, params.asset_type) - .await - .map_err(|e| { - tracing::warn!("RPC connection error: {}", e); - RestError::RpcConnectionError { - message: format!("RPC connection error: {}", e), - } - })?; +) -> Result>, RestError> +where + S: PriceFeedMeta, +{ + let state = &state.state; + let price_feeds_metadata = state + .get_price_feeds_metadata(params.query, params.asset_type) + .await + .map_err(|e| { + tracing::warn!("RPC connection error: {}", e); + RestError::RpcConnectionError { + message: format!("RPC connection error: {}", e), + } + })?; Ok(Json(price_feeds_metadata)) } diff --git a/hermes/src/network/pythnet.rs b/hermes/src/network/pythnet.rs index a8380c75c2..fa410ce3bc 100644 --- a/hermes/src/network/pythnet.rs +++ b/hermes/src/network/pythnet.rs @@ -17,7 +17,7 @@ use { GuardianSetData, }, price_feeds_metadata::{ - store_price_feeds_metadata, + PriceFeedMeta, DEFAULT_PRICE_FEEDS_CACHE_UPDATE_INTERVAL, }, state::State, @@ -353,13 +353,18 @@ pub async fn spawn(opts: RunOptions, state: Arc) -> Result<()> { } -pub async fn fetch_and_store_price_feeds_metadata( - state: &State, +pub async fn fetch_and_store_price_feeds_metadata( + state: &S, mapping_address: &Pubkey, rpc_client: &RpcClient, -) -> Result> { +) -> Result> +where + S: PriceFeedMeta, +{ let price_feeds_metadata = fetch_price_feeds_metadata(mapping_address, rpc_client).await?; - store_price_feeds_metadata(state, &price_feeds_metadata).await?; + state + .store_price_feeds_metadata(&price_feeds_metadata) + .await?; Ok(price_feeds_metadata) } diff --git a/hermes/src/price_feeds_metadata.rs b/hermes/src/price_feeds_metadata.rs index 9a6a62760b..cb0d51733b 100644 --- a/hermes/src/price_feeds_metadata.rs +++ b/hermes/src/price_feeds_metadata.rs @@ -7,49 +7,88 @@ use { state::State, }, anyhow::Result, + tokio::sync::RwLock, }; pub const DEFAULT_PRICE_FEEDS_CACHE_UPDATE_INTERVAL: u64 = 600; -pub async fn retrieve_price_feeds_metadata(state: &State) -> Result> { - let price_feeds_metadata = state.price_feeds_metadata.read().await; - Ok(price_feeds_metadata.clone()) +pub struct PriceFeedMetaState { + pub data: RwLock>, } -pub async fn store_price_feeds_metadata( - state: &State, - price_feeds_metadata: &[PriceFeedMetadata], -) -> Result<()> { - let mut price_feeds_metadata_write_guard = state.price_feeds_metadata.write().await; - *price_feeds_metadata_write_guard = price_feeds_metadata.to_vec(); - Ok(()) +impl PriceFeedMetaState { + pub fn new() -> Self { + Self { + data: RwLock::new(Vec::new()), + } + } } +/// Allow downcasting State into CacheState for functions that depend on the `Cache` service. +impl<'a> From<&'a State> for &'a PriceFeedMetaState { + fn from(state: &'a State) -> &'a PriceFeedMetaState { + &state.price_feed_meta + } +} + +pub trait PriceFeedMeta { + async fn retrieve_price_feeds_metadata(&self) -> Result>; + async fn store_price_feeds_metadata( + &self, + price_feeds_metadata: &[PriceFeedMetadata], + ) -> Result<()>; + async fn get_price_feeds_metadata( + &self, + query: Option, + asset_type: Option, + ) -> Result>; +} -pub async fn get_price_feeds_metadata( - state: &State, - query: Option, - asset_type: Option, -) -> Result> { - let mut price_feeds_metadata = retrieve_price_feeds_metadata(state).await?; - - // Filter by query if provided - if let Some(query_str) = &query { - price_feeds_metadata.retain(|feed| { - feed.attributes.get("symbol").map_or(false, |symbol| { - symbol.to_lowercase().contains(&query_str.to_lowercase()) - }) - }); +impl PriceFeedMeta for T +where + for<'a> &'a T: Into<&'a PriceFeedMetaState>, + T: Sync, +{ + async fn retrieve_price_feeds_metadata(&self) -> Result> { + let price_feeds_metadata = self.into().data.read().await; + Ok(price_feeds_metadata.clone()) } - // Filter by asset_type if provided - if let Some(asset_type) = &asset_type { - price_feeds_metadata.retain(|feed| { - feed.attributes.get("asset_type").map_or(false, |type_str| { - type_str.to_lowercase() == asset_type.to_string().to_lowercase() - }) - }); + async fn store_price_feeds_metadata( + &self, + price_feeds_metadata: &[PriceFeedMetadata], + ) -> Result<()> { + let mut price_feeds_metadata_write_guard = self.into().data.write().await; + *price_feeds_metadata_write_guard = price_feeds_metadata.to_vec(); + Ok(()) } - Ok(price_feeds_metadata) + + async fn get_price_feeds_metadata( + &self, + query: Option, + asset_type: Option, + ) -> Result> { + let mut price_feeds_metadata = self.retrieve_price_feeds_metadata().await?; + + // Filter by query if provided + if let Some(query_str) = &query { + price_feeds_metadata.retain(|feed| { + feed.attributes.get("symbol").map_or(false, |symbol| { + symbol.to_lowercase().contains(&query_str.to_lowercase()) + }) + }); + } + + // Filter by asset_type if provided + if let Some(asset_type) = &asset_type { + price_feeds_metadata.retain(|feed| { + feed.attributes.get("asset_type").map_or(false, |type_str| { + type_str.to_lowercase() == asset_type.to_string().to_lowercase() + }) + }); + } + + Ok(price_feeds_metadata) + } } diff --git a/hermes/src/state.rs b/hermes/src/state.rs index e6d2ca2b1e..4fc46714d3 100644 --- a/hermes/src/state.rs +++ b/hermes/src/state.rs @@ -10,8 +10,8 @@ use { AggregateState, AggregationEvent, }, - api::types::PriceFeedMetadata, network::wormhole::GuardianSet, + price_feeds_metadata::PriceFeedMetaState, }, prometheus_client::registry::Registry, reqwest::Url, @@ -38,6 +38,9 @@ pub struct State { /// State for the `Benchmarks` service for looking up historical updates. pub benchmarks: BenchmarksState, + /// State for the `PriceFeedMeta` service for looking up metadata related to Pyth price feeds. + pub price_feed_meta: PriceFeedMetaState, + /// Sequence numbers of lately observed Vaas. Store uses this set /// to ignore the previously observed Vaas as a performance boost. pub observed_vaa_seqs: RwLock>, @@ -53,9 +56,6 @@ pub struct State { /// Metrics registry pub metrics_registry: RwLock, - - /// Price feeds metadata - pub price_feeds_metadata: RwLock>, } impl State { @@ -66,14 +66,14 @@ impl State { ) -> Arc { let mut metrics_registry = Registry::default(); Arc::new(Self { - cache: CacheState::new(cache_size), - benchmarks: BenchmarksState::new(benchmarks_endpoint), - observed_vaa_seqs: RwLock::new(Default::default()), - guardian_set: RwLock::new(Default::default()), - api_update_tx: update_tx, - aggregate_state: RwLock::new(AggregateState::new(&mut metrics_registry)), - metrics_registry: RwLock::new(metrics_registry), - price_feeds_metadata: RwLock::new(Default::default()), + cache: CacheState::new(cache_size), + benchmarks: BenchmarksState::new(benchmarks_endpoint), + price_feed_meta: PriceFeedMetaState::new(), + observed_vaa_seqs: RwLock::new(Default::default()), + guardian_set: RwLock::new(Default::default()), + api_update_tx: update_tx, + aggregate_state: RwLock::new(AggregateState::new(&mut metrics_registry)), + metrics_registry: RwLock::new(metrics_registry), }) } }