Skip to content

Commit 54c8323

Browse files
committed
refactor(hermes): make State hidden to force APIs
1 parent b6f5bf1 commit 54c8323

File tree

12 files changed

+137
-81
lines changed

12 files changed

+137
-81
lines changed

apps/hermes/src/api.rs

Lines changed: 29 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,12 @@
11
use {
22
crate::{
33
config::RunOptions,
4-
state::State,
4+
state::{
5+
Aggregates,
6+
Benchmarks,
7+
Cache,
8+
Metrics,
9+
},
510
},
611
anyhow::Result,
712
axum::{
@@ -24,7 +29,7 @@ mod rest;
2429
pub mod types;
2530
mod ws;
2631

27-
pub struct ApiState<S = State> {
32+
pub struct ApiState<S> {
2833
pub state: Arc<S>,
2934
pub ws: Arc<ws::WsState>,
3035
pub metrics: Arc<metrics_middleware::ApiMetrics>,
@@ -42,12 +47,12 @@ impl<S> Clone for ApiState<S> {
4247
}
4348
}
4449

45-
impl ApiState<State> {
46-
pub fn new(
47-
state: Arc<State>,
48-
ws_whitelist: Vec<IpNet>,
49-
requester_ip_header_name: String,
50-
) -> Self {
50+
impl<S> ApiState<S> {
51+
pub fn new(state: Arc<S>, ws_whitelist: Vec<IpNet>, requester_ip_header_name: String) -> Self
52+
where
53+
S: Metrics,
54+
S: Send + Sync + 'static,
55+
{
5156
Self {
5257
metrics: Arc::new(metrics_middleware::ApiMetrics::new(state.clone())),
5358
ws: Arc::new(ws::WsState::new(
@@ -61,7 +66,14 @@ impl ApiState<State> {
6166
}
6267

6368
#[tracing::instrument(skip(opts, state))]
64-
pub async fn spawn(opts: RunOptions, state: Arc<State>) -> Result<()> {
69+
pub async fn spawn<S>(opts: RunOptions, state: Arc<S>) -> Result<()>
70+
where
71+
S: Aggregates,
72+
S: Benchmarks,
73+
S: Cache,
74+
S: Metrics,
75+
S: Send + Sync + 'static,
76+
{
6577
let state = {
6678
let opts = opts.clone();
6779
ApiState::new(
@@ -79,7 +91,14 @@ pub async fn spawn(opts: RunOptions, state: Arc<State>) -> Result<()> {
7991
/// Currently this is based on Axum due to the simplicity and strong ecosystem support for the
8092
/// packages they are based on (tokio & hyper).
8193
#[tracing::instrument(skip(opts, state))]
82-
pub async fn run(opts: RunOptions, state: ApiState) -> Result<()> {
94+
pub async fn run<S>(opts: RunOptions, state: ApiState<S>) -> Result<()>
95+
where
96+
S: Aggregates,
97+
S: Benchmarks,
98+
S: Cache,
99+
S: Metrics,
100+
S: Send + Sync + 'static,
101+
{
83102
tracing::info!(endpoint = %opts.rpc.listen_addr, "Starting RPC Server.");
84103

85104
#[derive(OpenApi)]

apps/hermes/src/api/metrics_middleware.rs

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -31,9 +31,7 @@ impl ApiMetrics {
3131
pub fn new<S>(state: Arc<S>) -> Self
3232
where
3333
S: Metrics,
34-
S: Send,
35-
S: Sync,
36-
S: 'static,
34+
S: Send + Sync + 'static,
3735
{
3836
let new = Self {
3937
requests: Family::default(),
@@ -81,8 +79,8 @@ pub struct Labels {
8179
pub status: u16,
8280
}
8381

84-
pub async fn track_metrics<B>(
85-
State(api_state): State<ApiState>,
82+
pub async fn track_metrics<B, S>(
83+
State(api_state): State<ApiState<S>>,
8684
req: Request<B>,
8785
next: Next<B>,
8886
) -> impl IntoResponse {

apps/hermes/src/api/rest/v2/price_feeds_metadata.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ use {
88
},
99
ApiState,
1010
},
11-
price_feeds_metadata::PriceFeedMeta,
11+
state::price_feeds_metadata::PriceFeedMeta,
1212
},
1313
anyhow::Result,
1414
axum::{

apps/hermes/src/api/rest/v2/sse.rs

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -95,9 +95,7 @@ pub async fn price_stream_sse_handler<S>(
9595
) -> Result<Sse<impl Stream<Item = Result<Event, Infallible>>>, RestError>
9696
where
9797
S: Aggregates,
98-
S: Sync,
99-
S: Send,
100-
S: 'static,
98+
S: Send + Sync + 'static,
10199
{
102100
let price_ids: Vec<PriceIdentifier> = params.ids.into_iter().map(Into::into).collect();
103101

apps/hermes/src/api/ws.rs

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,9 @@ use {
1313
RequestTime,
1414
},
1515
metrics::Metrics,
16-
State,
16+
Benchmarks,
17+
Cache,
18+
PriceFeedMeta,
1719
},
1820
anyhow::{
1921
anyhow,
@@ -124,9 +126,7 @@ impl WsMetrics {
124126
pub fn new<S>(state: Arc<S>) -> Self
125127
where
126128
S: Metrics,
127-
S: Send,
128-
S: Sync,
129-
S: 'static,
129+
S: Send + Sync + 'static,
130130
{
131131
let new = Self {
132132
interactions: Family::default(),
@@ -161,7 +161,11 @@ pub struct WsState {
161161
}
162162

163163
impl WsState {
164-
pub fn new(whitelist: Vec<IpNet>, requester_ip_header_name: String, state: Arc<State>) -> Self {
164+
pub fn new<S>(whitelist: Vec<IpNet>, requester_ip_header_name: String, state: Arc<S>) -> Self
165+
where
166+
S: Metrics,
167+
S: Send + Sync + 'static,
168+
{
165169
Self {
166170
subscriber_counter: AtomicUsize::new(0),
167171
rate_limiter: RateLimiter::dashmap(Quota::per_second(nonzero!(
@@ -211,11 +215,18 @@ enum ServerResponseMessage {
211215
Err { error: String },
212216
}
213217

214-
pub async fn ws_route_handler(
218+
pub async fn ws_route_handler<S>(
215219
ws: WebSocketUpgrade,
216-
AxumState(state): AxumState<super::ApiState>,
220+
AxumState(state): AxumState<ApiState<S>>,
217221
headers: HeaderMap,
218-
) -> impl IntoResponse {
222+
) -> impl IntoResponse
223+
where
224+
S: Aggregates,
225+
S: Benchmarks,
226+
S: Cache,
227+
S: PriceFeedMeta,
228+
S: Send + Sync + 'static,
229+
{
219230
let requester_ip = headers
220231
.get(state.ws.requester_ip_header_name.as_str())
221232
.and_then(|value| value.to_str().ok())
@@ -230,6 +241,7 @@ pub async fn ws_route_handler(
230241
async fn websocket_handler<S>(stream: WebSocket, state: ApiState<S>, subscriber_ip: Option<IpAddr>)
231242
where
232243
S: Aggregates,
244+
S: Send,
233245
{
234246
let ws_state = state.ws.clone();
235247

apps/hermes/src/main.rs

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@ use {
99
},
1010
futures::future::join_all,
1111
lazy_static::lazy_static,
12-
state::State,
1312
std::io::IsTerminal,
1413
tokio::{
1514
spawn,
@@ -21,7 +20,6 @@ mod api;
2120
mod config;
2221
mod metrics_server;
2322
mod network;
24-
mod price_feeds_metadata;
2523
mod serde;
2624
mod state;
2725

@@ -53,7 +51,7 @@ async fn init() -> Result<()> {
5351
let (update_tx, _) = tokio::sync::broadcast::channel(1000);
5452

5553
// Initialize a cache store with a 1000 element circular buffer.
56-
let state = State::new(update_tx.clone(), 1000, opts.benchmarks.endpoint.clone());
54+
let state = state::new(update_tx.clone(), 1000, opts.benchmarks.endpoint.clone());
5755

5856
// Listen for Ctrl+C so we can set the exit flag and wait for a graceful shutdown.
5957
spawn(async move {

apps/hermes/src/metrics_server.rs

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,7 @@
55
use {
66
crate::{
77
config::RunOptions,
8-
state::{
9-
metrics::Metrics,
10-
State as AppState,
11-
},
8+
state::metrics::Metrics,
129
},
1310
anyhow::Result,
1411
axum::{
@@ -23,7 +20,11 @@ use {
2320

2421

2522
#[tracing::instrument(skip(opts, state))]
26-
pub async fn run(opts: RunOptions, state: Arc<AppState>) -> Result<()> {
23+
pub async fn run<S>(opts: RunOptions, state: Arc<S>) -> Result<()>
24+
where
25+
S: Metrics,
26+
S: Send + Sync + 'static,
27+
{
2728
tracing::info!(endpoint = %opts.metrics.server_listen_addr, "Starting Metrics Server.");
2829

2930
let app = Router::new();
@@ -44,7 +45,10 @@ pub async fn run(opts: RunOptions, state: Arc<AppState>) -> Result<()> {
4445
Ok(())
4546
}
4647

47-
pub async fn metrics(State(state): State<Arc<AppState>>) -> impl IntoResponse {
48+
pub async fn metrics<S>(State(state): State<Arc<S>>) -> impl IntoResponse
49+
where
50+
S: Metrics,
51+
{
4852
let buffer = Metrics::encode(&*state).await;
4953
(
5054
[(

apps/hermes/src/network/pythnet.rs

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -11,18 +11,17 @@ use {
1111
GuardianSet,
1212
GuardianSetData,
1313
},
14-
price_feeds_metadata::{
15-
PriceFeedMeta,
16-
DEFAULT_PRICE_FEEDS_CACHE_UPDATE_INTERVAL,
17-
},
1814
state::{
1915
aggregate::{
2016
AccumulatorMessages,
2117
Aggregates,
2218
Update,
2319
},
20+
price_feeds_metadata::{
21+
PriceFeedMeta,
22+
DEFAULT_PRICE_FEEDS_CACHE_UPDATE_INTERVAL,
23+
},
2424
wormhole::Wormhole,
25-
State,
2625
},
2726
},
2827
anyhow::{
@@ -139,7 +138,12 @@ async fn fetch_bridge_data(
139138
}
140139
}
141140

142-
pub async fn run(store: Arc<State>, pythnet_ws_endpoint: String) -> Result<!> {
141+
pub async fn run<S>(store: Arc<S>, pythnet_ws_endpoint: String) -> Result<!>
142+
where
143+
S: Aggregates,
144+
S: Wormhole,
145+
S: Send + Sync + 'static,
146+
{
143147
let client = PubsubClient::new(pythnet_ws_endpoint.as_ref()).await?;
144148

145149
let config = RpcProgramAccountsConfig {
@@ -222,6 +226,7 @@ async fn fetch_existing_guardian_sets<S>(
222226
) -> Result<()>
223227
where
224228
S: Wormhole,
229+
S: Send + Sync + 'static,
225230
{
226231
let client = RpcClient::new(pythnet_http_endpoint.to_string());
227232
let bridge = fetch_bridge_data(&client, &wormhole_contract_addr).await?;
@@ -261,7 +266,11 @@ where
261266
}
262267

263268
#[tracing::instrument(skip(opts, state))]
264-
pub async fn spawn(opts: RunOptions, state: Arc<State>) -> Result<()> {
269+
pub async fn spawn<S>(opts: RunOptions, state: Arc<S>) -> Result<()>
270+
where
271+
S: Wormhole,
272+
S: Send + Sync + 'static,
273+
{
265274
tracing::info!(endpoint = opts.pythnet.ws_addr, "Started Pythnet Listener.");
266275

267276
// Create RpcClient instance here

apps/hermes/src/network/wormhole.rs

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,7 @@
77
use {
88
crate::{
99
config::RunOptions,
10-
state::{
11-
wormhole::Wormhole,
12-
State,
13-
},
10+
state::wormhole::Wormhole,
1411
},
1512
anyhow::{
1613
anyhow,
@@ -118,7 +115,11 @@ mod proto {
118115

119116
// Launches the Wormhole gRPC service.
120117
#[tracing::instrument(skip(opts, state))]
121-
pub async fn spawn(opts: RunOptions, state: Arc<State>) -> Result<()> {
118+
pub async fn spawn<S>(opts: RunOptions, state: Arc<S>) -> Result<()>
119+
where
120+
S: Wormhole,
121+
S: Send + Sync + 'static,
122+
{
122123
let mut exit = crate::EXIT.subscribe();
123124
loop {
124125
let current_time = Instant::now();
@@ -142,9 +143,7 @@ pub async fn spawn(opts: RunOptions, state: Arc<State>) -> Result<()> {
142143
async fn run<S>(opts: RunOptions, state: Arc<S>) -> Result<!>
143144
where
144145
S: Wormhole,
145-
S: Sync,
146-
S: Send,
147-
S: 'static,
146+
S: Send + Sync + 'static,
148147
{
149148
let mut client = SpyRpcServiceClient::connect(opts.wormhole.spy_rpc_addr).await?;
150149
let mut stream = client

0 commit comments

Comments
 (0)