11use {
2+ super :: State ,
23 crate :: aggregate:: {
34 wormhole_merkle:: WormholeMerkleState ,
45 AccumulatorMessages ,
@@ -96,79 +97,42 @@ pub enum MessageStateFilter {
9697 Only ( MessageType ) ,
9798}
9899
99- pub struct Cache {
100- /// Accumulator messages cache
101- ///
102- /// We do not write to this cache much, so we can use a simple RwLock instead of a DashMap.
103- accumulator_messages_cache : Arc < RwLock < BTreeMap < Slot , AccumulatorMessages > > > ,
104-
105- /// Wormhole merkle state cache
106- ///
107- /// We do not write to this cache much, so we can use a simple RwLock instead of a DashMap.
108- wormhole_merkle_state_cache : Arc < RwLock < BTreeMap < Slot , WormholeMerkleState > > > ,
100+ /// A Cache of AccumulatorMessage by slot. We do not write to this cache much, so we can use a simple RwLock instead of a DashMap.
101+ type AccumulatorMessagesCache = Arc < RwLock < BTreeMap < Slot , AccumulatorMessages > > > ;
109102
110- message_cache : Arc < RwLock < HashMap < MessageStateKey , BTreeMap < MessageStateTime , MessageState > > > > ,
111- cache_size : u64 ,
112- }
113-
114- async fn retrieve_message_state (
115- cache : & Cache ,
116- key : MessageStateKey ,
117- request_time : RequestTime ,
118- ) -> Option < MessageState > {
119- match cache. message_cache . read ( ) . await . get ( & key) {
120- Some ( key_cache) => {
121- match request_time {
122- RequestTime :: Latest => key_cache. last_key_value ( ) . map ( |( _, v) | v) . cloned ( ) ,
123- RequestTime :: FirstAfter ( time) => {
124- // If the requested time is before the first element in the vector, we are
125- // not sure that the first element is the closest one.
126- if let Some ( ( _, oldest_record_value) ) = key_cache. first_key_value ( ) {
127- if time < oldest_record_value. time ( ) . publish_time {
128- return None ;
129- }
130- }
103+ /// A Cache of WormholeMerkleState by slot. We do not write to this cache much, so we can use a simple RwLock instead of a DashMap.
104+ type WormholeMerkleStateCache = Arc < RwLock < BTreeMap < Slot , WormholeMerkleState > > > ;
131105
132- let lookup_time = MessageStateTime {
133- publish_time : time,
134- slot : 0 ,
135- } ;
106+ /// A Cache of `Time<->MessageState` by feed id.
107+ type MessageCache = Arc < RwLock < HashMap < MessageStateKey , BTreeMap < MessageStateTime , MessageState > > > > ;
136108
137- // Get the first element that is greater than or equal to the lookup time.
138- key_cache
139- . lower_bound ( Bound :: Included ( & lookup_time) )
140- . peek_next ( )
141- . map ( |( _, v) | v)
142- . cloned ( )
143- }
144- RequestTime :: AtSlot ( slot) => {
145- // Get the state with slot equal to the lookup slot.
146- key_cache
147- . iter ( )
148- . rev ( ) // Usually the slot lies at the end of the map
149- . find ( |( k, _) | k. slot == slot)
150- . map ( |( _, v) | v)
151- . cloned ( )
152- }
153- }
154- }
155- None => None ,
156- }
109+ /// A collection of caches for various program state.
110+ pub struct CacheState {
111+ accumulator_messages_cache : AccumulatorMessagesCache ,
112+ wormhole_merkle_state_cache : WormholeMerkleStateCache ,
113+ message_cache : MessageCache ,
114+ cache_size : u64 ,
157115}
158116
159- impl Cache {
160- pub fn new ( cache_size : u64 ) -> Self {
117+ impl CacheState {
118+ pub fn new ( size : u64 ) -> Self {
161119 Self {
162- message_cache : Arc :: new ( RwLock :: new ( HashMap :: new ( ) ) ) ,
163- accumulator_messages_cache : Arc :: new ( RwLock :: new ( BTreeMap :: new ( ) ) ) ,
120+ accumulator_messages_cache : Arc :: new ( RwLock :: new ( BTreeMap :: new ( ) ) ) ,
164121 wormhole_merkle_state_cache : Arc :: new ( RwLock :: new ( BTreeMap :: new ( ) ) ) ,
165- cache_size,
122+ message_cache : Arc :: new ( RwLock :: new ( HashMap :: new ( ) ) ) ,
123+ cache_size : size,
166124 }
167125 }
168126}
169127
170- #[ async_trait:: async_trait]
171- pub trait AggregateCache {
128+ /// Allow downcasting State into CacheState for functions that depend on the `Cache` service.
129+ impl < ' a > From < & ' a State > for & ' a CacheState {
130+ fn from ( state : & ' a State ) -> & ' a CacheState {
131+ & state. cache
132+ }
133+ }
134+
135+ pub trait Cache {
172136 async fn message_state_keys ( & self ) -> Vec < MessageStateKey > ;
173137 async fn store_message_states ( & self , message_states : Vec < MessageState > ) -> Result < ( ) > ;
174138 async fn prune_removed_keys ( & self , current_keys : HashSet < MessageStateKey > ) ;
@@ -190,10 +154,13 @@ pub trait AggregateCache {
190154 async fn fetch_wormhole_merkle_state ( & self , slot : Slot ) -> Result < Option < WormholeMerkleState > > ;
191155}
192156
193- #[ async_trait:: async_trait]
194- impl AggregateCache for crate :: state:: State {
157+ impl < T > Cache for T
158+ where
159+ for < ' a > & ' a T : Into < & ' a CacheState > ,
160+ T : Sync ,
161+ {
195162 async fn message_state_keys ( & self ) -> Vec < MessageStateKey > {
196- self . cache
163+ self . into ( )
197164 . message_cache
198165 . read ( )
199166 . await
@@ -203,7 +170,7 @@ impl AggregateCache for crate::state::State {
203170 }
204171
205172 async fn store_message_states ( & self , message_states : Vec < MessageState > ) -> Result < ( ) > {
206- let mut message_cache = self . cache . message_cache . write ( ) . await ;
173+ let mut message_cache = self . into ( ) . message_cache . write ( ) . await ;
207174
208175 for message_state in message_states {
209176 let key = message_state. key ( ) ;
@@ -212,7 +179,7 @@ impl AggregateCache for crate::state::State {
212179 cache. insert ( time, message_state) ;
213180
214181 // Remove the earliest message states if the cache size is exceeded
215- while cache. len ( ) > self . cache . cache_size as usize {
182+ while cache. len ( ) > self . into ( ) . cache_size as usize {
216183 cache. pop_first ( ) ;
217184 }
218185 }
@@ -227,7 +194,7 @@ impl AggregateCache for crate::state::State {
227194 /// lose the cache for that key and cannot retrieve it for historical
228195 /// price queries.
229196 async fn prune_removed_keys ( & self , current_keys : HashSet < MessageStateKey > ) {
230- let mut message_cache = self . cache . message_cache . write ( ) . await ;
197+ let mut message_cache = self . into ( ) . message_cache . write ( ) . await ;
231198
232199 // Sometimes, some keys are removed from the accumulator. We track which keys are not
233200 // present in the message states and remove them from the cache.
@@ -262,7 +229,7 @@ impl AggregateCache for crate::state::State {
262229 feed_id : id,
263230 type_ : message_type,
264231 } ;
265- retrieve_message_state ( & self . cache , key, request_time. clone ( ) )
232+ retrieve_message_state ( self . into ( ) , key, request_time. clone ( ) )
266233 } )
267234 } ) )
268235 . await
@@ -275,60 +242,95 @@ impl AggregateCache for crate::state::State {
275242 & self ,
276243 accumulator_messages : AccumulatorMessages ,
277244 ) -> Result < ( ) > {
278- let mut cache = self . cache . accumulator_messages_cache . write ( ) . await ;
245+ let mut cache = self . into ( ) . accumulator_messages_cache . write ( ) . await ;
279246 cache. insert ( accumulator_messages. slot , accumulator_messages) ;
280- while cache. len ( ) > self . cache . cache_size as usize {
247+ while cache. len ( ) > self . into ( ) . cache_size as usize {
281248 cache. pop_first ( ) ;
282249 }
283250 Ok ( ( ) )
284251 }
285252
286253 async fn fetch_accumulator_messages ( & self , slot : Slot ) -> Result < Option < AccumulatorMessages > > {
287- let cache = self . cache . accumulator_messages_cache . read ( ) . await ;
254+ let cache = self . into ( ) . accumulator_messages_cache . read ( ) . await ;
288255 Ok ( cache. get ( & slot) . cloned ( ) )
289256 }
290257
291258 async fn store_wormhole_merkle_state (
292259 & self ,
293260 wormhole_merkle_state : WormholeMerkleState ,
294261 ) -> Result < ( ) > {
295- let mut cache = self . cache . wormhole_merkle_state_cache . write ( ) . await ;
262+ let mut cache = self . into ( ) . wormhole_merkle_state_cache . write ( ) . await ;
296263 cache. insert ( wormhole_merkle_state. root . slot , wormhole_merkle_state) ;
297- while cache. len ( ) > self . cache . cache_size as usize {
264+ while cache. len ( ) > self . into ( ) . cache_size as usize {
298265 cache. pop_first ( ) ;
299266 }
300267 Ok ( ( ) )
301268 }
302269
303270 async fn fetch_wormhole_merkle_state ( & self , slot : Slot ) -> Result < Option < WormholeMerkleState > > {
304- let cache = self . cache . wormhole_merkle_state_cache . read ( ) . await ;
271+ let cache = self . into ( ) . wormhole_merkle_state_cache . read ( ) . await ;
305272 Ok ( cache. get ( & slot) . cloned ( ) )
306273 }
307274}
308275
276+ async fn retrieve_message_state (
277+ cache : & CacheState ,
278+ key : MessageStateKey ,
279+ request_time : RequestTime ,
280+ ) -> Option < MessageState > {
281+ match cache. message_cache . read ( ) . await . get ( & key) {
282+ Some ( key_cache) => {
283+ match request_time {
284+ RequestTime :: Latest => key_cache. last_key_value ( ) . map ( |( _, v) | v) . cloned ( ) ,
285+ RequestTime :: FirstAfter ( time) => {
286+ // If the requested time is before the first element in the vector, we are
287+ // not sure that the first element is the closest one.
288+ if let Some ( ( _, oldest_record_value) ) = key_cache. first_key_value ( ) {
289+ if time < oldest_record_value. time ( ) . publish_time {
290+ return None ;
291+ }
292+ }
293+
294+ let lookup_time = MessageStateTime {
295+ publish_time : time,
296+ slot : 0 ,
297+ } ;
298+
299+ // Get the first element that is greater than or equal to the lookup time.
300+ key_cache
301+ . lower_bound ( Bound :: Included ( & lookup_time) )
302+ . peek_next ( )
303+ . map ( |( _, v) | v)
304+ . cloned ( )
305+ }
306+ RequestTime :: AtSlot ( slot) => {
307+ // Get the state with slot equal to the lookup slot.
308+ key_cache
309+ . iter ( )
310+ . rev ( ) // Usually the slot lies at the end of the map
311+ . find ( |( k, _) | k. slot == slot)
312+ . map ( |( _, v) | v)
313+ . cloned ( )
314+ }
315+ }
316+ }
317+ None => None ,
318+ }
319+ }
320+
309321#[ cfg( test) ]
310322mod test {
311323 use {
312324 super :: * ,
313325 crate :: {
314- aggregate:: {
315- wormhole_merkle:: {
316- WormholeMerkleMessageProof ,
317- WormholeMerkleState ,
318- } ,
319- AccumulatorMessages ,
320- ProofSet ,
321- } ,
326+ aggregate:: wormhole_merkle:: WormholeMerkleMessageProof ,
322327 state:: test:: setup_state,
323328 } ,
324329 pyth_sdk:: UnixTimestamp ,
325330 pythnet_sdk:: {
326331 accumulators:: merkle:: MerklePath ,
327332 hashers:: keccak256_160:: Keccak160 ,
328- messages:: {
329- Message ,
330- PriceFeedMessage ,
331- } ,
333+ messages:: PriceFeedMessage ,
332334 wire:: v1:: WormholeMerkleRoot ,
333335 } ,
334336 } ;
@@ -369,7 +371,7 @@ mod test {
369371 slot : Slot ,
370372 ) -> MessageState
371373 where
372- S : AggregateCache ,
374+ S : Cache ,
373375 {
374376 let message_state = create_dummy_price_feed_message_state ( feed_id, publish_time, slot) ;
375377 state
0 commit comments