Skip to content

Commit 4ae7f28

Browse files
committed
Load leaves and compare in-memory
1 parent a043d2f commit 4ae7f28

File tree

2 files changed

+68
-22
lines changed

2 files changed

+68
-22
lines changed

trie/zk_trie.go

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -234,15 +234,15 @@ func VerifyProofSMT(rootHash common.Hash, key []byte, proofDb ethdb.KeyValueRead
234234
}
235235
}
236236

237-
func (t *ZkTrie) CountLeaves() uint64 {
237+
func (t *ZkTrie) CountLeaves(cb func(key, value []byte)) uint64 {
238238
root, err := t.ZkTrie.Tree().Root()
239239
if err != nil {
240240
panic("CountLeaves cannot get root")
241241
}
242-
return t.countLeaves(root)
242+
return t.countLeaves(root, cb, 0)
243243
}
244244

245-
func (t *ZkTrie) countLeaves(root *zkt.Hash) uint64 {
245+
func (t *ZkTrie) countLeaves(root *zkt.Hash, cb func(key, value []byte), depth int) uint64 {
246246
if root == nil {
247247
return 0
248248
}
@@ -253,9 +253,23 @@ func (t *ZkTrie) countLeaves(root *zkt.Hash) uint64 {
253253
}
254254

255255
if rootNode.Type == zktrie.NodeTypeLeaf_New {
256+
cb(append([]byte{}, rootNode.NodeKey.Bytes()...), append([]byte{}, rootNode.Data()...))
256257
return 1
257258
} else {
258-
return t.countLeaves(rootNode.ChildL) + t.countLeaves(rootNode.ChildR)
259+
count := make(chan uint64)
260+
if depth < 5 {
261+
leftT := t.Copy()
262+
rightT := t.Copy()
263+
go func() {
264+
count <- leftT.countLeaves(rootNode.ChildL, cb, depth+1)
265+
}()
266+
go func() {
267+
count <- rightT.countLeaves(rootNode.ChildR, cb, depth+1)
268+
}()
269+
return <-count + <-count
270+
} else {
271+
return t.countLeaves(rootNode.ChildL, cb, depth+1) + t.countLeaves(rootNode.ChildR, cb, depth+1)
272+
}
259273
}
260274
}
261275

trie/zk_trie_test.go

Lines changed: 50 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ package trie
1919
import (
2020
"bytes"
2121
"encoding/binary"
22+
"encoding/hex"
2223
"io/ioutil"
2324
"os"
2425
"runtime"
@@ -33,6 +34,7 @@ import (
3334
"github.com/scroll-tech/go-ethereum/common"
3435
"github.com/scroll-tech/go-ethereum/core/rawdb"
3536
"github.com/scroll-tech/go-ethereum/core/types"
37+
"github.com/scroll-tech/go-ethereum/crypto"
3638
"github.com/scroll-tech/go-ethereum/ethdb/leveldb"
3739
"github.com/scroll-tech/go-ethereum/ethdb/memorydb"
3840
"github.com/scroll-tech/go-ethereum/rlp"
@@ -292,28 +294,58 @@ type dbs struct {
292294
mptDb *leveldb.Database
293295
}
294296

295-
var accountsLeft = -1
297+
var accountsDone = 0
296298

297299
func checkTrieEquality(t *testing.T, dbs *dbs, zkRoot, mptRoot common.Hash, leafChecker func(*testing.T, *dbs, []byte, []byte)) {
298-
zkTrie, err := NewZkTrie(zkRoot, NewZktrieDatabase(dbs.zkDb))
300+
zkTrie, err := NewZkTrie(zkRoot, NewZktrieDatabaseFromTriedb(NewDatabaseWithConfig(dbs.zkDb, &Config{Preimages: true})))
299301
require.NoError(t, err)
300302

301303
mptTrie, err := NewSecure(mptRoot, NewDatabaseWithConfig(dbs.mptDb, &Config{Preimages: true}))
302304
require.NoError(t, err)
303305

304-
expectedLeaves := zkTrie.CountLeaves()
305-
trieIt := NewIterator(mptTrie.NodeIterator(nil))
306-
if accountsLeft == -1 {
307-
accountsLeft = int(expectedLeaves)
306+
dup := func(s []byte) []byte {
307+
return append([]byte{}, s...)
308308
}
309309

310-
for trieIt.Next() {
311-
expectedLeaves--
312-
preimageKey := mptTrie.GetKey(trieIt.Key)
313-
require.NotEmpty(t, preimageKey)
314-
leafChecker(t, dbs, zkTrie.Get(preimageKey), mptTrie.Get(preimageKey))
310+
mptLeafMap := make(map[string][]byte, 1000)
311+
trieIt := NewIterator(mptTrie.NodeIterator(nil))
312+
mptDone := make(chan struct{})
313+
go func() {
314+
for trieIt.Next() {
315+
mptLeafMap[string(dup(trieIt.Key))] = dup(trieIt.Value)
316+
if len(mptLeafMap)%10000 == 0 {
317+
t.Log("MPT Accounts Loaded:", len(mptLeafMap))
318+
}
319+
}
320+
close(mptDone)
321+
}()
322+
323+
zkLeafMap := make(map[string][]byte, 1000)
324+
var zkLeafMutex sync.Mutex
325+
zkDone := make(chan struct{})
326+
go func() {
327+
zkTrie.CountLeaves(func(key, value []byte) {
328+
preimageKey := zkTrie.GetKey(key)
329+
require.NotEmpty(t, preimageKey)
330+
zkLeafMutex.Lock()
331+
zkLeafMap[string(dup(preimageKey))] = value
332+
zkLeafMutex.Unlock()
333+
if len(zkLeafMap)%10000 == 0 {
334+
t.Log("ZK Accounts Loaded:", len(zkLeafMap))
335+
}
336+
})
337+
close(zkDone)
338+
}()
339+
340+
<-zkDone
341+
<-mptDone
342+
require.Equal(t, len(mptLeafMap), len(zkLeafMap))
343+
for preimageKey, zkValue := range zkLeafMap {
344+
mptKey := crypto.Keccak256([]byte(preimageKey))
345+
mptVal, ok := mptLeafMap[string(mptKey)]
346+
require.True(t, ok, "key %s not found in mpt", hex.EncodeToString([]byte(preimageKey)))
347+
leafChecker(t, dbs, zkValue, mptVal)
315348
}
316-
require.Zero(t, expectedLeaves)
317349
}
318350

319351
func checkAccountEquality(t *testing.T, dbs *dbs, zkAccountBytes, mptAccountBytes []byte) {
@@ -322,18 +354,18 @@ func checkAccountEquality(t *testing.T, dbs *dbs, zkAccountBytes, mptAccountByte
322354
zkAccount, err := types.UnmarshalStateAccount(zkAccountBytes)
323355
require.NoError(t, err)
324356

325-
require.Equal(t, mptAccount.Nonce, zkAccount.Nonce)
326-
require.True(t, mptAccount.Balance.Cmp(zkAccount.Balance) == 0)
327-
require.Equal(t, mptAccount.KeccakCodeHash, zkAccount.KeccakCodeHash)
357+
require.Equal(t, mptAccount.Nonce, zkAccount.Nonce, "nonce zk: %d, mpt: %d", zkAccount.Nonce, mptAccount.Nonce)
358+
require.True(t, mptAccount.Balance.Cmp(zkAccount.Balance) == 0, "balance zk: %s, mpt: %s", zkAccount.Balance.String(), mptAccount.Balance.String())
359+
require.Equal(t, mptAccount.KeccakCodeHash, zkAccount.KeccakCodeHash, "code hash zk: %s, mpt: %s", hex.EncodeToString(zkAccount.KeccakCodeHash), hex.EncodeToString(mptAccount.KeccakCodeHash))
328360
checkTrieEquality(t, dbs, common.BytesToHash(zkAccount.Root[:]), common.BytesToHash(mptAccount.Root[:]), checkStorageEquality)
329-
accountsLeft--
330-
t.Log("Accounts left:", accountsLeft)
361+
accountsDone++
362+
t.Log("Accounts done:", accountsDone)
331363
}
332364

333365
func checkStorageEquality(t *testing.T, _ *dbs, zkStorageBytes, mptStorageBytes []byte) {
334366
zkValue := common.BytesToHash(zkStorageBytes)
335367
_, content, _, err := rlp.Split(mptStorageBytes)
336368
require.NoError(t, err)
337369
mptValue := common.BytesToHash(content)
338-
require.Equal(t, zkValue, mptValue)
370+
require.Equal(t, zkValue, mptValue, "storage zk: %s, mpt: %s", zkValue.Hex(), mptValue.Hex())
339371
}

0 commit comments

Comments
 (0)