diff --git a/codechain/codechain.yml b/codechain/codechain.yml index 43dba554be..5dcab05b39 100644 --- a/codechain/codechain.yml +++ b/codechain/codechain.yml @@ -259,6 +259,16 @@ args: takes_value: true conflicts_with: - no-discovery + - snapshot-hash: + long: snapshot-hash + value_name: HASH + requires: snapshot-number + takes_value: true + - snapshot-number: + long: snapshot-number + value_name: NUM + requires: snapshot-hash + takes_value: true - no-snapshot: long: no-snapshot help: Disable snapshots diff --git a/codechain/config/mod.rs b/codechain/config/mod.rs index a75df5cdeb..63b01fb928 100644 --- a/codechain/config/mod.rs +++ b/codechain/config/mod.rs @@ -25,6 +25,7 @@ use cidr::IpCidr; use ckey::PlatformAddress; use clap; use cnetwork::{FilterEntry, NetworkConfig, SocketAddr}; +use primitives::H256; use toml; pub use self::chain_type::ChainType; @@ -242,6 +243,8 @@ pub struct Network { pub min_peers: Option, pub max_peers: Option, pub sync: Option, + pub snapshot_hash: Option, + pub snapshot_number: Option, pub transaction_relay: Option, pub discovery: Option, pub discovery_type: Option, @@ -500,6 +503,12 @@ impl Network { if other.sync.is_some() { self.sync = other.sync; } + if other.snapshot_hash.is_some() { + self.snapshot_hash = other.snapshot_hash; + } + if other.snapshot_number.is_some() { + self.snapshot_number = other.snapshot_number; + } if other.transaction_relay.is_some() { self.transaction_relay = other.transaction_relay; } @@ -552,6 +561,12 @@ impl Network { if matches.is_present("no-sync") { self.sync = Some(false); } + if let Some(snapshot_hash) = matches.value_of("snapshot-hash") { + self.snapshot_hash = Some(snapshot_hash.parse().map_err(|_| "Invalid snapshot-hash")?); + } + if let Some(snapshot_number) = matches.value_of("snapshot-number") { + self.snapshot_number = Some(snapshot_number.parse().map_err(|_| "Invalid snapshot-number")?); + } if matches.is_present("no-tx-relay") { self.transaction_relay = Some(false); } diff --git a/codechain/run_node.rs b/codechain/run_node.rs index 86a5d7b6ab..acb37e572c 100644 --- a/codechain/run_node.rs +++ b/codechain/run_node.rs @@ -300,7 +300,11 @@ pub fn run_node(matches: &ArgMatches) -> Result<(), String> { if config.network.sync.unwrap() { let sync_sender = { let client = client.client(); - service.register_extension(move |api| BlockSyncExtension::new(client, api)) + let snapshot_target = match (config.network.snapshot_hash, config.network.snapshot_number) { + (Some(hash), Some(num)) => Some((hash, num)), + _ => None, + }; + service.register_extension(move |api| BlockSyncExtension::new(client, api, snapshot_target)) }; let sync = Arc::new(BlockSyncSender::from(sync_sender.clone())); client.client().add_notify(Arc::downgrade(&sync) as Weak); diff --git a/sync/Cargo.toml b/sync/Cargo.toml index 86db229173..40f2303e4c 100644 --- a/sync/Cargo.toml +++ b/sync/Cargo.toml @@ -14,6 +14,7 @@ codechain-network = { path = "../network" } codechain-state = { path = "../state" } codechain-timer = { path = "../util/timer" } codechain-types = { path = "../types" } +hashdb = { path = "../util/hashdb" } journaldb = { path = "../util/journaldb" } kvdb = { path = "../util/kvdb" } log = "0.4.6" @@ -28,7 +29,6 @@ token-generator = "0.1.0" util-error = { path = "../util/error" } [dev-dependencies] -hashdb = { path = "../util/hashdb" } kvdb-memorydb = { path = "../util/kvdb-memorydb" } tempfile = "3.0.4" trie-standardmap = { path = "../util/trie-standardmap" } diff --git a/sync/src/block/extension.rs b/sync/src/block/extension.rs index f089980c29..76c6c8c957 100644 --- a/sync/src/block/extension.rs +++ b/sync/src/block/extension.rs @@ -24,12 +24,14 @@ use ccore::{ Block, BlockChainClient, BlockChainTrait, BlockId, BlockImportError, ChainNotify, Client, ImportBlock, ImportError, UnverifiedTransaction, }; +use cmerkle::TrieFactory; use cnetwork::{Api, EventSender, NetworkExtension, NodeId}; use cstate::FindActionHandler; use ctimer::TimerToken; use ctypes::header::{Header, Seal}; use ctypes::transaction::Action; use ctypes::{BlockHash, BlockNumber}; +use hashdb::AsHashDB; use primitives::{H256, U256}; use rand::prelude::SliceRandom; use rand::thread_rng; @@ -55,7 +57,14 @@ pub struct TokenInfo { request_id: Option, } +enum State { + SnapshotHeader(H256), + SnapshotChunk(H256), + Full, +} + pub struct Extension { + state: State, requests: HashMap>, connected_nodes: HashSet, header_downloaders: HashMap, @@ -69,9 +78,22 @@ pub struct Extension { } impl Extension { - pub fn new(client: Arc, api: Box) -> Extension { + pub fn new(client: Arc, api: Box, snapshot_target: Option<(H256, u64)>) -> Extension { api.set_timer(SYNC_TIMER_TOKEN, Duration::from_millis(SYNC_TIMER_INTERVAL)).expect("Timer set succeeds"); + let state = match snapshot_target { + Some((hash, num)) => match client.block_header(&BlockId::Number(num)) { + Some(ref header) if *header.hash() == hash => { + let state_db = client.state_db().read(); + match TrieFactory::readonly(state_db.as_hashdb(), &header.state_root()) { + Ok(ref trie) if trie.is_complete() => State::Full, + _ => State::SnapshotChunk(*header.hash()), + } + } + _ => State::SnapshotHeader(hash), + }, + None => State::Full, + }; let mut header = client.best_header(); let mut hollow_headers = vec![header.decode()]; while client.block_body(&BlockId::Hash(header.hash())).is_none() { @@ -89,6 +111,7 @@ impl Extension { } cinfo!(SYNC, "Sync extension initialized"); Extension { + state, requests: Default::default(), connected_nodes: Default::default(), header_downloaders: Default::default(), @@ -286,31 +309,35 @@ impl NetworkExtension for Extension { fn on_timeout(&mut self, token: TimerToken) { match token { - SYNC_TIMER_TOKEN => { - let best_proposal_score = self.client.chain_info().best_proposal_score; - let mut peer_ids: Vec<_> = self.header_downloaders.keys().cloned().collect(); - peer_ids.shuffle(&mut thread_rng()); - - for id in &peer_ids { - let request = self.header_downloaders.get_mut(id).and_then(HeaderDownloader::create_request); - if let Some(request) = request { - self.send_header_request(id, request); - break + SYNC_TIMER_TOKEN => match self.state { + State::SnapshotHeader(..) => unimplemented!(), + State::SnapshotChunk(..) => unimplemented!(), + State::Full => { + let best_proposal_score = self.client.chain_info().best_proposal_score; + let mut peer_ids: Vec<_> = self.header_downloaders.keys().cloned().collect(); + peer_ids.shuffle(&mut thread_rng()); + + for id in &peer_ids { + let request = self.header_downloaders.get_mut(id).and_then(HeaderDownloader::create_request); + if let Some(request) = request { + self.send_header_request(id, request); + break + } } - } - for id in peer_ids { - let peer_score = if let Some(peer) = self.header_downloaders.get(&id) { - peer.total_score() - } else { - U256::zero() - }; + for id in peer_ids { + let peer_score = if let Some(peer) = self.header_downloaders.get(&id) { + peer.total_score() + } else { + U256::zero() + }; - if peer_score > best_proposal_score { - self.send_body_request(&id); + if peer_score > best_proposal_score { + self.send_body_request(&id); + } } } - } + }, SYNC_EXPIRE_TOKEN_BEGIN..=SYNC_EXPIRE_TOKEN_END => { self.check_sync_variable(); let (id, request_id) = { @@ -572,33 +599,37 @@ impl Extension { return } - match response { - ResponseMessage::Headers(headers) => { - self.dismiss_request(from, id); - self.on_header_response(from, &headers) - } - ResponseMessage::Bodies(bodies) => { - self.check_sync_variable(); - let hashes = match request { - RequestMessage::Bodies(hashes) => hashes, - _ => unreachable!(), - }; - assert_eq!(bodies.len(), hashes.len()); - if let Some(token) = self.tokens.get(from) { - if let Some(token_info) = self.tokens_info.get_mut(token) { - if token_info.request_id.is_none() { - ctrace!(SYNC, "Expired before handling response"); - return + match self.state { + State::SnapshotHeader(..) => unimplemented!(), + State::SnapshotChunk(..) => unimplemented!(), + State::Full => match response { + ResponseMessage::Headers(headers) => { + self.dismiss_request(from, id); + self.on_header_response(from, &headers) + } + ResponseMessage::Bodies(bodies) => { + self.check_sync_variable(); + let hashes = match request { + RequestMessage::Bodies(hashes) => hashes, + _ => unreachable!(), + }; + assert_eq!(bodies.len(), hashes.len()); + if let Some(token) = self.tokens.get(from) { + if let Some(token_info) = self.tokens_info.get_mut(token) { + if token_info.request_id.is_none() { + ctrace!(SYNC, "Expired before handling response"); + return + } + self.api.clear_timer(*token).expect("Timer clear succeed"); + token_info.request_id = None; } - self.api.clear_timer(*token).expect("Timer clear succeed"); - token_info.request_id = None; } + self.dismiss_request(from, id); + self.on_body_response(hashes, bodies); + self.check_sync_variable(); } - self.dismiss_request(from, id); - self.on_body_response(hashes, bodies); - self.check_sync_variable(); - } - _ => unimplemented!(), + _ => unimplemented!(), + }, } } } diff --git a/sync/src/lib.rs b/sync/src/lib.rs index d67fc4c8c7..b89deb036d 100644 --- a/sync/src/lib.rs +++ b/sync/src/lib.rs @@ -25,7 +25,6 @@ extern crate codechain_state as cstate; extern crate codechain_timer as ctimer; extern crate codechain_types as ctypes; -#[cfg(test)] extern crate hashdb; extern crate journaldb; extern crate kvdb; diff --git a/util/merkle/src/triedb.rs b/util/merkle/src/triedb.rs index c65ce58c6a..8f43f1c356 100644 --- a/util/merkle/src/triedb.rs +++ b/util/merkle/src/triedb.rs @@ -14,7 +14,7 @@ // You should have received a copy of the GNU Affero General Public License // along with this program. If not, see . -use ccrypto::blake256; +use ccrypto::{blake256, BLAKE_NULL_RLP}; use hashdb::HashDB; use primitives::H256; @@ -105,6 +105,26 @@ impl<'db> TrieDB<'db> { None => Ok(None), } } + + /// Check if every leaf of the trie exists + pub fn is_complete(&self) -> bool { + *self.root == BLAKE_NULL_RLP || self.is_complete_aux(self.root) + } + + /// Check if every leaf of the trie starting from `hash` exists + pub fn is_complete_aux(&self, hash: &H256) -> bool { + if let Some(node_rlp) = self.db.get(hash) { + match RlpNode::decoded(node_rlp.as_ref()) { + Some(RlpNode::Branch(.., children)) => { + children.iter().flatten().all(|child| self.is_complete_aux(child)) + } + Some(RlpNode::Leaf(..)) => true, + None => false, + } + } else { + false + } + } } impl<'db> Trie for TrieDB<'db> { @@ -126,6 +146,19 @@ mod tests { use crate::*; use memorydb::*; + fn delete_any_child(db: &mut MemoryDB, root: &H256) { + let node_rlp = db.get(root).unwrap(); + match RlpNode::decoded(&node_rlp).unwrap() { + RlpNode::Leaf(..) => { + db.remove(root); + } + RlpNode::Branch(.., children) => { + let first_child = children.iter().find(|c| c.is_some()).unwrap().unwrap(); + db.remove(&first_child); + } + } + } + #[test] fn get() { let mut memdb = MemoryDB::new(); @@ -141,4 +174,33 @@ mod tests { assert_eq!(t.get(b"B"), Ok(Some(DBValue::from_slice(b"ABCBA")))); assert_eq!(t.get(b"C"), Ok(None)); } + + #[test] + fn is_complete_success() { + let mut memdb = MemoryDB::new(); + let mut root = H256::new(); + { + let mut t = TrieDBMut::new(&mut memdb, &mut root); + t.insert(b"A", b"ABC").unwrap(); + t.insert(b"B", b"ABCBA").unwrap(); + } + + let t = TrieDB::try_new(&memdb, &root).unwrap(); + assert!(t.is_complete()); + } + + #[test] + fn is_complete_fail() { + let mut memdb = MemoryDB::new(); + let mut root = H256::new(); + { + let mut t = TrieDBMut::new(&mut memdb, &mut root); + t.insert(b"A", b"ABC").unwrap(); + t.insert(b"B", b"ABCBA").unwrap(); + } + delete_any_child(&mut memdb, &root); + + let t = TrieDB::try_new(&memdb, &root).unwrap(); + assert!(!t.is_complete()); + } }