@@ -24,12 +24,14 @@ use ccore::{
2424 Block , BlockChainClient , BlockChainTrait , BlockId , BlockImportError , ChainNotify , Client , ImportBlock , ImportError ,
2525 UnverifiedTransaction ,
2626} ;
27+ use cmerkle:: TrieFactory ;
2728use cnetwork:: { Api , EventSender , NetworkExtension , NodeId } ;
2829use cstate:: FindActionHandler ;
2930use ctimer:: TimerToken ;
3031use ctypes:: header:: { Header , Seal } ;
3132use ctypes:: transaction:: Action ;
3233use ctypes:: { BlockHash , BlockNumber } ;
34+ use hashdb:: AsHashDB ;
3335use primitives:: { H256 , U256 } ;
3436use rand:: prelude:: SliceRandom ;
3537use rand:: thread_rng;
@@ -55,7 +57,14 @@ pub struct TokenInfo {
5557 request_id : Option < u64 > ,
5658}
5759
60+ enum State {
61+ SnapshotHeader ( H256 ) ,
62+ SnapshotChunk ( H256 ) ,
63+ Full ,
64+ }
65+
5866pub struct Extension {
67+ state : State ,
5968 requests : HashMap < NodeId , Vec < ( u64 , RequestMessage ) > > ,
6069 connected_nodes : HashSet < NodeId > ,
6170 header_downloaders : HashMap < NodeId , HeaderDownloader > ,
@@ -69,9 +78,22 @@ pub struct Extension {
6978}
7079
7180impl Extension {
72- pub fn new ( client : Arc < Client > , api : Box < dyn Api > , _snapshot_target : Option < ( H256 , u64 ) > ) -> Extension {
81+ pub fn new ( client : Arc < Client > , api : Box < dyn Api > , snapshot_target : Option < ( H256 , u64 ) > ) -> Extension {
7382 api. set_timer ( SYNC_TIMER_TOKEN , Duration :: from_millis ( SYNC_TIMER_INTERVAL ) ) . expect ( "Timer set succeeds" ) ;
7483
84+ let state = match snapshot_target {
85+ Some ( ( hash, num) ) => match client. block_header ( & BlockId :: Number ( num) ) {
86+ Some ( ref header) if * header. hash ( ) == hash => {
87+ let state_db = client. state_db ( ) . read ( ) ;
88+ match TrieFactory :: readonly ( state_db. as_hashdb ( ) , & header. state_root ( ) ) {
89+ Ok ( ref trie) if trie. is_complete ( ) => State :: Full ,
90+ _ => State :: SnapshotChunk ( * header. hash ( ) ) ,
91+ }
92+ }
93+ _ => State :: SnapshotHeader ( hash) ,
94+ } ,
95+ None => State :: Full ,
96+ } ;
7597 let mut header = client. best_header ( ) ;
7698 let mut hollow_headers = vec ! [ header. decode( ) ] ;
7799 while client. block_body ( & BlockId :: Hash ( header. hash ( ) ) ) . is_none ( ) {
@@ -89,6 +111,7 @@ impl Extension {
89111 }
90112 cinfo ! ( SYNC , "Sync extension initialized" ) ;
91113 Extension {
114+ state,
92115 requests : Default :: default ( ) ,
93116 connected_nodes : Default :: default ( ) ,
94117 header_downloaders : Default :: default ( ) ,
@@ -286,31 +309,35 @@ impl NetworkExtension<Event> for Extension {
286309
287310 fn on_timeout ( & mut self , token : TimerToken ) {
288311 match token {
289- SYNC_TIMER_TOKEN => {
290- let best_proposal_score = self . client . chain_info ( ) . best_proposal_score ;
291- let mut peer_ids: Vec < _ > = self . header_downloaders . keys ( ) . cloned ( ) . collect ( ) ;
292- peer_ids. shuffle ( & mut thread_rng ( ) ) ;
293-
294- for id in & peer_ids {
295- let request = self . header_downloaders . get_mut ( id) . and_then ( HeaderDownloader :: create_request) ;
296- if let Some ( request) = request {
297- self . send_header_request ( id, request) ;
298- break
312+ SYNC_TIMER_TOKEN => match self . state {
313+ State :: SnapshotHeader ( ..) => unimplemented ! ( ) ,
314+ State :: SnapshotChunk ( ..) => unimplemented ! ( ) ,
315+ State :: Full => {
316+ let best_proposal_score = self . client . chain_info ( ) . best_proposal_score ;
317+ let mut peer_ids: Vec < _ > = self . header_downloaders . keys ( ) . cloned ( ) . collect ( ) ;
318+ peer_ids. shuffle ( & mut thread_rng ( ) ) ;
319+
320+ for id in & peer_ids {
321+ let request = self . header_downloaders . get_mut ( id) . and_then ( HeaderDownloader :: create_request) ;
322+ if let Some ( request) = request {
323+ self . send_header_request ( id, request) ;
324+ break
325+ }
299326 }
300- }
301327
302- for id in peer_ids {
303- let peer_score = if let Some ( peer) = self . header_downloaders . get ( & id) {
304- peer. total_score ( )
305- } else {
306- U256 :: zero ( )
307- } ;
328+ for id in peer_ids {
329+ let peer_score = if let Some ( peer) = self . header_downloaders . get ( & id) {
330+ peer. total_score ( )
331+ } else {
332+ U256 :: zero ( )
333+ } ;
308334
309- if peer_score > best_proposal_score {
310- self . send_body_request ( & id) ;
335+ if peer_score > best_proposal_score {
336+ self . send_body_request ( & id) ;
337+ }
311338 }
312339 }
313- }
340+ } ,
314341 SYNC_EXPIRE_TOKEN_BEGIN ..=SYNC_EXPIRE_TOKEN_END => {
315342 self . check_sync_variable ( ) ;
316343 let ( id, request_id) = {
@@ -572,33 +599,37 @@ impl Extension {
572599 return
573600 }
574601
575- match response {
576- ResponseMessage :: Headers ( headers) => {
577- self . dismiss_request ( from, id) ;
578- self . on_header_response ( from, & headers)
579- }
580- ResponseMessage :: Bodies ( bodies) => {
581- self . check_sync_variable ( ) ;
582- let hashes = match request {
583- RequestMessage :: Bodies ( hashes) => hashes,
584- _ => unreachable ! ( ) ,
585- } ;
586- assert_eq ! ( bodies. len( ) , hashes. len( ) ) ;
587- if let Some ( token) = self . tokens . get ( from) {
588- if let Some ( token_info) = self . tokens_info . get_mut ( token) {
589- if token_info. request_id . is_none ( ) {
590- ctrace ! ( SYNC , "Expired before handling response" ) ;
591- return
602+ match self . state {
603+ State :: SnapshotHeader ( ..) => unimplemented ! ( ) ,
604+ State :: SnapshotChunk ( ..) => unimplemented ! ( ) ,
605+ State :: Full => match response {
606+ ResponseMessage :: Headers ( headers) => {
607+ self . dismiss_request ( from, id) ;
608+ self . on_header_response ( from, & headers)
609+ }
610+ ResponseMessage :: Bodies ( bodies) => {
611+ self . check_sync_variable ( ) ;
612+ let hashes = match request {
613+ RequestMessage :: Bodies ( hashes) => hashes,
614+ _ => unreachable ! ( ) ,
615+ } ;
616+ assert_eq ! ( bodies. len( ) , hashes. len( ) ) ;
617+ if let Some ( token) = self . tokens . get ( from) {
618+ if let Some ( token_info) = self . tokens_info . get_mut ( token) {
619+ if token_info. request_id . is_none ( ) {
620+ ctrace ! ( SYNC , "Expired before handling response" ) ;
621+ return
622+ }
623+ self . api . clear_timer ( * token) . expect ( "Timer clear succeed" ) ;
624+ token_info. request_id = None ;
592625 }
593- self . api . clear_timer ( * token) . expect ( "Timer clear succeed" ) ;
594- token_info. request_id = None ;
595626 }
627+ self . dismiss_request ( from, id) ;
628+ self . on_body_response ( hashes, bodies) ;
629+ self . check_sync_variable ( ) ;
596630 }
597- self . dismiss_request ( from, id) ;
598- self . on_body_response ( hashes, bodies) ;
599- self . check_sync_variable ( ) ;
600- }
601- _ => unimplemented ! ( ) ,
631+ _ => unimplemented ! ( ) ,
632+ } ,
602633 }
603634 }
604635 }
0 commit comments