diff --git a/arbitrum/recordingdb.go b/arbitrum/recordingdb.go
index d3147807d0..a6f9818896 100644
--- a/arbitrum/recordingdb.go
+++ b/arbitrum/recordingdb.go
@@ -2,9 +2,11 @@ package arbitrum
 
 import (
 	"bytes"
+	"context"
 	"encoding/hex"
 	"errors"
 	"fmt"
+	"sync"
 
 	"github.com/ethereum/go-ethereum/common"
 	"github.com/ethereum/go-ethereum/consensus"
@@ -12,6 +14,7 @@ import (
 	"github.com/ethereum/go-ethereum/core/rawdb"
 	"github.com/ethereum/go-ethereum/core/state"
 	"github.com/ethereum/go-ethereum/core/types"
+	"github.com/ethereum/go-ethereum/core/vm"
 	"github.com/ethereum/go-ethereum/crypto"
 	"github.com/ethereum/go-ethereum/ethdb"
 	"github.com/ethereum/go-ethereum/log"
@@ -25,7 +28,7 @@ type RecordingKV struct {
 	enableBypass  bool
 }
 
-func NewRecordingKV(inner *trie.Database) *RecordingKV {
+func newRecordingKV(inner *trie.Database) *RecordingKV {
 	return &RecordingKV{inner, make(map[common.Hash][]byte), false}
 }
 
@@ -126,7 +129,7 @@ type RecordingChainContext struct {
 	initialBlockNumber     uint64
 }
 
-func NewRecordingChainContext(inner core.ChainContext, blocknumber uint64) *RecordingChainContext {
+func newRecordingChainContext(inner core.ChainContext, blocknumber uint64) *RecordingChainContext {
 	return &RecordingChainContext{
 		bc:                     inner,
 		minBlockNumberAccessed: blocknumber,
@@ -149,9 +152,101 @@ func (r *RecordingChainContext) GetMinBlockNumberAccessed() uint64 {
 	return r.minBlockNumberAccessed
 }
 
-func PrepareRecording(blockchain *core.BlockChain, lastBlockHeader *types.Header) (*state.StateDB, core.ChainContext, *RecordingKV, error) {
-	rawTrie := blockchain.StateCache().TrieDB()
-	recordingKeyValue := NewRecordingKV(rawTrie)
+// RecordingDatabase wraps a state.Database together with the blockchain it
+// belongs to and counts the trie-root references it currently holds.
+type RecordingDatabase struct {
+	db         state.Database
+	bc         *core.BlockChain
+	mutex      sync.Mutex // guards references and trie root (de)referencing
+	references int64
+}
+
+// NewRecordingDatabase creates a RecordingDatabase on top of ethdb, using a
+// state database whose trie config sets Cache to 16.
+func NewRecordingDatabase(ethdb ethdb.Database, blockchain *core.BlockChain) *RecordingDatabase {
+	return &RecordingDatabase{
+		db: state.NewDatabaseWithConfig(ethdb, &trie.Config{Cache: 16}), //TODO cache needed? configurable?
+		bc: blockchain,
+	}
+}
+
+// StateFor opens the state at header.Root and, on success, takes a reference
+// on that root. Normal geth state.New + Reference is not atomic vs
+// Dereference; this one is, because both steps run under r.mutex.
+// This function does not recreate a state.
+func (r *RecordingDatabase) StateFor(header *types.Header) (*state.StateDB, error) {
+	r.mutex.Lock()
+	defer r.mutex.Unlock()
+
+	sdb, err := state.NewDeterministic(header.Root, r.db)
+	if err == nil {
+		r.referenceRootLockHeld(header.Root)
+	}
+	return sdb, err
+}
+
+// Dereference drops one reference on header.Root. A nil header is a no-op.
+func (r *RecordingDatabase) Dereference(header *types.Header) {
+	if header != nil {
+		r.dereferenceRoot(header.Root)
+	}
+}
+
+// WriteStateToDatabase commits the trie rooted at header.Root through the
+// trie database. A nil header is a no-op and returns nil.
+func (r *RecordingDatabase) WriteStateToDatabase(header *types.Header) error {
+	if header != nil {
+		return r.db.TrieDB().Commit(header.Root, true, nil)
+	}
+	return nil
+}
+
+// referenceRootLockHeld bumps the reference counter and references root in
+// the trie database. r.mutex must be held by the caller.
+func (r *RecordingDatabase) referenceRootLockHeld(root common.Hash) {
+	r.references++
+	r.db.TrieDB().Reference(root, common.Hash{})
+}
+
+// dereferenceRoot takes r.mutex, decrements the reference counter and
+// dereferences root in the trie database.
+func (r *RecordingDatabase) dereferenceRoot(root common.Hash) {
+	r.mutex.Lock()
+	defer r.mutex.Unlock()
+	r.references--
+	r.db.TrieDB().Dereference(root)
+}
+
+// addStateVerify commits statedb, checks that the resulting root equals
+// expected and, if so, takes a reference on it.
+func (r *RecordingDatabase) addStateVerify(statedb *state.StateDB, expected common.Hash) error {
+	r.mutex.Lock()
+	defer r.mutex.Unlock()
+	result, err := statedb.Commit(true)
+	if err != nil {
+		return err
+	}
+	if result != expected {
+		return fmt.Errorf("bad root hash expected: %v got: %v", expected, result)
+	}
+	r.referenceRootLockHeld(result)
+	return nil
+}
+
+// StateBuildingLogFunction is called by GetOrRecreateState to report
+// progress: targetHeader is the requested header, header the one currently
+// being examined, and hasState is false while searching backwards for an
+// existing state and true while re-executing blocks forwards.
+type StateBuildingLogFunction func(targetHeader, header *types.Header, hasState bool)
+
+func (r *RecordingDatabase) PrepareRecording(ctx context.Context, lastBlockHeader *types.Header, logFunc StateBuildingLogFunction) (*state.StateDB, core.ChainContext, *RecordingKV, error) {
+	_, err := r.GetOrRecreateState(ctx, lastBlockHeader, logFunc)
+	if err != nil {
+		return nil, nil, nil, err
+	}
+	finalDereference := lastBlockHeader // dereference in case of error
+	defer func() { r.Dereference(finalDereference) }()
+	recordingKeyValue := newRecordingKV(r.db.TrieDB())
 
 	recordingStateDatabase := state.NewDatabase(rawdb.NewDatabase(recordingKeyValue))
 	var prevRoot common.Hash
@@ -167,29 +262,117 @@ func PrepareRecording(blockchain *core.BlockChain, lastBlockHeader *types.Header
 		if !lastBlockHeader.Number.IsUint64() {
 			return nil, nil, nil, errors.New("block number not uint64")
 		}
-		recordingChainContext = NewRecordingChainContext(blockchain, lastBlockHeader.Number.Uint64())
+		recordingChainContext = newRecordingChainContext(r.bc, lastBlockHeader.Number.Uint64())
 	}
+	finalDereference = nil
 	return recordingStateDb, recordingChainContext, recordingKeyValue, nil
 }
 
-func PreimagesFromRecording(chainContextIf core.ChainContext, recordingDb *RecordingKV) (map[common.Hash][]byte, error) {
+// PreimagesFromRecording returns the entries recorded by recordingDb plus the
+// RLP encoding of every header from the lowest block number accessed through
+// chainContextIf up to its initial block number, keyed by header hash.
+func (r *RecordingDatabase) PreimagesFromRecording(chainContextIf core.ChainContext, recordingDb *RecordingKV) (map[common.Hash][]byte, error) {
 	entries := recordingDb.GetRecordedEntries()
 	recordingChainContext, ok := chainContextIf.(*RecordingChainContext)
 	if (recordingChainContext == nil) || (!ok) {
 		return nil, errors.New("recordingChainContext invalid")
 	}
-	blockchain, ok := recordingChainContext.bc.(*core.BlockChain)
-	if (blockchain == nil) || (!ok) {
-		return nil, errors.New("blockchain invalid")
-	}
+
 	for i := recordingChainContext.GetMinBlockNumberAccessed(); i <= recordingChainContext.initialBlockNumber; i++ {
-		header := blockchain.GetHeaderByNumber(i)
+		header := r.bc.GetHeaderByNumber(i)
+		if header == nil {
+			return nil, fmt.Errorf("header not found for block %d", i)
+		}
 		hash := header.Hash()
 		bytes, err := rlp.EncodeToBytes(header)
 		if err != nil {
-			panic(fmt.Sprintf("Error RLP encoding header: %v\n", err))
+			return nil, fmt.Errorf("error RLP encoding header: %w", err)
 		}
 		entries[hash] = bytes
 	}
 	return entries, nil
 }
+
+// GetOrRecreateState returns the state for header with a reference taken on
+// its root. If that state is not available it walks back through parent
+// headers until one with available state is found, then re-executes the
+// blocks from there up to header, committing and verifying each
+// intermediate root. Only the final root stays referenced on success; the
+// caller is expected to release it with Dereference. logFunc may be nil.
+func (r *RecordingDatabase) GetOrRecreateState(ctx context.Context, header *types.Header, logFunc StateBuildingLogFunction) (*state.StateDB, error) {
+	stateDb, err := r.StateFor(header)
+	if err == nil {
+		return stateDb, nil
+	}
+	returnedBlockNumber := header.Number.Uint64()
+	genesis := r.bc.Config().ArbitrumChainParams.GenesisBlockNum
+	currentHeader := header
+	var lastRoot common.Hash
+	for ctx.Err() == nil {
+		if logFunc != nil {
+			logFunc(header, currentHeader, false)
+		}
+		if currentHeader.Number.Uint64() <= genesis {
+			return nil, fmt.Errorf("moved beyond genesis looking for state looking for %d, genesis %d, err %w", returnedBlockNumber, genesis, err)
+		}
+		// keep currentHeader intact so the error below can still describe it
+		parentHeader := r.bc.GetHeader(currentHeader.ParentHash, currentHeader.Number.Uint64()-1)
+		if parentHeader == nil {
+			return nil, fmt.Errorf("chain doesn't contain parent of block %d hash %v", currentHeader.Number, currentHeader.Hash())
+		}
+		currentHeader = parentHeader
+		stateDb, err = r.StateFor(currentHeader)
+		if err == nil {
+			lastRoot = currentHeader.Root
+			break
+		}
+	}
+	// release the last intermediate root on every exit path that still holds one
+	defer func() {
+		if (lastRoot != common.Hash{}) {
+			r.dereferenceRoot(lastRoot)
+		}
+	}()
+	blockToRecreate := currentHeader.Number.Uint64() + 1
+	prevHash := currentHeader.Hash()
+	for ctx.Err() == nil {
+		block := r.bc.GetBlockByNumber(blockToRecreate)
+		if block == nil {
+			return nil, fmt.Errorf("block not found while recreating: %d", blockToRecreate)
+		}
+		if block.ParentHash() != prevHash {
+			return nil, fmt.Errorf("reorg detected: number %d expectedPrev: %v foundPrev: %v", blockToRecreate, prevHash, block.ParentHash())
+		}
+		prevHash = block.Hash()
+		if logFunc != nil {
+			logFunc(header, block.Header(), true)
+		}
+		_, _, _, err := r.bc.Processor().Process(block, stateDb, vm.Config{})
+		if err != nil {
+			return nil, fmt.Errorf("failed recreating state for block %d : %w", blockToRecreate, err)
+		}
+		err = r.addStateVerify(stateDb, block.Root())
+		if err != nil {
+			return nil, fmt.Errorf("failed commiting state for block %d : %w", blockToRecreate, err)
+		}
+		r.dereferenceRoot(lastRoot)
+		lastRoot = block.Root()
+		if blockToRecreate >= returnedBlockNumber {
+			if block.Hash() != header.Hash() {
+				return nil, fmt.Errorf("blockHash doesn't match when recreating number: %d expected: %v got: %v", blockToRecreate, header.Hash(), block.Hash())
+			}
+			// don't dereference this one
+			lastRoot = common.Hash{}
+			return stateDb, nil
+		}
+		blockToRecreate++
+	}
+	return nil, ctx.Err()
+}
+
+// ReferenceCount returns the number of trie-root references currently held.
+func (r *RecordingDatabase) ReferenceCount() int64 {
+	r.mutex.Lock()
+	defer r.mutex.Unlock()
+	return r.references
+}
diff --git a/core/state/database.go b/core/state/database.go
index ce5d8d7317..eaaf5606c3 100644
--- a/core/state/database.go
+++ b/core/state/database.go
@@ -19,6 +19,7 @@ package state
 import (
 	"errors"
 	"fmt"
+	"runtime"
 
 	"github.com/VictoriaMetrics/fastcache"
 	"github.com/ethereum/go-ethereum/common"
@@ -118,11 +119,13 @@ func NewDatabase(db ethdb.Database) Database {
 // large memory cache.
 func NewDatabaseWithConfig(db ethdb.Database, config *trie.Config) Database {
 	csc, _ := lru.New(codeSizeCacheSize)
-	return &cachingDB{
+	cdb := &cachingDB{
 		db:            trie.NewDatabaseWithConfig(db, config),
 		codeSizeCache: csc,
 		codeCache:     fastcache.New(codeCacheSize),
 	}
+	runtime.SetFinalizer(cdb, (*cachingDB).finalizer)
+	return cdb
 }
 
 type cachingDB struct {
@@ -140,6 +143,11 @@ func (db *cachingDB) OpenTrie(root common.Hash) (Trie, error) {
 	return tr, nil
 }
 
+// finalizer resets the code cache: fastcache chunks are not managed by GC.
+func (db *cachingDB) finalizer() {
+	db.codeCache.Reset()
+}
+
 // OpenStorageTrie opens the storage trie of an account.
 func (db *cachingDB) OpenStorageTrie(addrHash, root common.Hash) (Trie, error) {
 	tr, err := trie.NewSecure(addrHash, root, db.db)
diff --git a/trie/database.go b/trie/database.go
index 9ac6dcec5c..650dbb81c2 100644
--- a/trie/database.go
+++ b/trie/database.go
@@ -297,6 +297,7 @@ func NewDatabaseWithConfig(diskdb ethdb.KeyValueStore, config *Config) *Database
 	if config == nil || config.Preimages { // TODO(karalabe): Flip to default off in the future
 		db.preimages = make(map[common.Hash][]byte)
 	}
+	runtime.SetFinalizer(db, (*Database).finalizer)
 	return db
 }
 
@@ -305,6 +306,13 @@ func (db *Database) DiskDB() ethdb.KeyValueStore {
 	return db.diskdb
 }
 
+// finalizer resets the clean cache, if any; must call Reset() to reclaim memory used by fastcache
+func (db *Database) finalizer() {
+	if db.cleans != nil {
+		db.cleans.Reset()
+	}
+}
+
 // insert inserts a collapsed trie node into the memory database.
 // The blob size must be specified to allow proper size tracking.
 // All nodes inserted by this function will be reference tracked