diff --git a/src/main/scala/io/iohk/ethereum/forkid/ForkIdValidator.scala b/src/main/scala/io/iohk/ethereum/forkid/ForkIdValidator.scala index 5cd8b061f0..860a7a08ce 100644 --- a/src/main/scala/io/iohk/ethereum/forkid/ForkIdValidator.scala +++ b/src/main/scala/io/iohk/ethereum/forkid/ForkIdValidator.scala @@ -22,9 +22,12 @@ case object Connect extends ForkIdValidationResult case object ErrRemoteStale extends ForkIdValidationResult case object ErrLocalIncompatibleOrStale extends ForkIdValidationResult +import cats.effect._ + object ForkIdValidator { - implicit val unsafeLogger: SelfAwareStructuredLogger[Task] = Slf4jLogger.getLogger[Task] + implicit val taskLogger: SelfAwareStructuredLogger[Task] = Slf4jLogger.getLogger[Task] + implicit val syncIoLogger: SelfAwareStructuredLogger[SyncIO] = Slf4jLogger.getLogger[SyncIO] val maxUInt64: BigInt = (BigInt(0x7fffffffffffffffL) << 1) + 1 // scalastyle:ignore magic.number diff --git a/src/main/scala/io/iohk/ethereum/network/handshaker/EthNodeStatus64ExchangeState.scala b/src/main/scala/io/iohk/ethereum/network/handshaker/EthNodeStatus64ExchangeState.scala index 5e2ea96a44..4118e42fe4 100644 --- a/src/main/scala/io/iohk/ethereum/network/handshaker/EthNodeStatus64ExchangeState.scala +++ b/src/main/scala/io/iohk/ethereum/network/handshaker/EthNodeStatus64ExchangeState.scala @@ -1,12 +1,17 @@ package io.iohk.ethereum.network.handshaker +import cats.effect.SyncIO + +import io.iohk.ethereum.forkid.Connect import io.iohk.ethereum.forkid.ForkId +import io.iohk.ethereum.forkid.ForkIdValidator import io.iohk.ethereum.network.EtcPeerManagerActor.PeerInfo import io.iohk.ethereum.network.EtcPeerManagerActor.RemoteStatus import io.iohk.ethereum.network.p2p.Message import io.iohk.ethereum.network.p2p.MessageSerializable import io.iohk.ethereum.network.p2p.messages.Capability import io.iohk.ethereum.network.p2p.messages.ETH64 +import io.iohk.ethereum.network.p2p.messages.WireProtocol.Disconnect case class EthNodeStatus64ExchangeState( handshakerConfiguration: EtcHandshakerConfiguration @@ -15,8 +20,17 @@ case class EthNodeStatus64ExchangeState( import handshakerConfiguration._ def applyResponseMessage: PartialFunction[Message, HandshakerState[PeerInfo]] = { case status: ETH64.Status => - // TODO: validate fork id of the remote peer - applyRemoteStatusMessage(RemoteStatus(status)) + import ForkIdValidator.syncIoLogger + (for { + validationResult <- + ForkIdValidator.validatePeer[SyncIO](blockchainReader.genesisHeader.hash, blockchainConfig)( + blockchainReader.getBestBlockNumber(), + status.forkId + ) + } yield validationResult match { + case Connect => applyRemoteStatusMessage(RemoteStatus(status)) + case _ => DisconnectedState[PeerInfo](Disconnect.Reasons.UselessPeer) + }).unsafeRunSync() } override protected def createStatusMsg(): MessageSerializable = { diff --git a/src/test/scala/io/iohk/ethereum/network/handshaker/EtcHandshakerSpec.scala b/src/test/scala/io/iohk/ethereum/network/handshaker/EtcHandshakerSpec.scala index 461aff25a9..47a947ee5d 100644 --- a/src/test/scala/io/iohk/ethereum/network/handshaker/EtcHandshakerSpec.scala +++ b/src/test/scala/io/iohk/ethereum/network/handshaker/EtcHandshakerSpec.scala @@ -190,7 +190,8 @@ class EtcHandshakerSpec extends AnyFlatSpec with Matchers { } } - it should "send status with fork id when peer supports ETH64" in new LocalPeerETH64Setup with RemotePeerETH64Setup { + it should "connect correctly after validating fork id when peer supports ETH64" in new LocalPeerETH64Setup + with RemotePeerETH64Setup { val newChainWeight = ChainWeight.zero.increase(genesisBlock.header).increase(firstBlock.header) @@ -223,6 +224,45 @@ class EtcHandshakerSpec extends AnyFlatSpec with Matchers { } } + it should "disconnect from a useless peer after validating fork id when peer supports ETH64" in new LocalPeerETH64Setup + with RemotePeerETH64Setup { + + val newChainWeight = ChainWeight.zero.increase(genesisBlock.header).increase(firstBlock.header) + + blockchainWriter.save(firstBlock, Nil, newChainWeight, saveAsBestBlock = true) + + val newLocalStatusMsg = + localStatusMsg + .copy( + bestHash = firstBlock.header.hash, + totalDifficulty = newChainWeight.totalDifficulty, + forkId = ForkId(0xfc64ec04L, Some(1150000)) + ) + + initHandshakerWithoutResolver.nextMessage.map(_.messageToSend) shouldBe Right(localHello: HelloEnc) + + val newRemoteStatusMsg = + remoteStatusMsg + .copy( + forkId = ForkId(1, None) // ForkId that is incompatible with our chain + ) + + val handshakerAfterHelloOpt = initHandshakerWithoutResolver.applyMessage(remoteHello) + assert(handshakerAfterHelloOpt.isDefined) + + handshakerAfterHelloOpt.get.nextMessage.map(_.messageToSend.underlyingMsg) shouldBe Right(newLocalStatusMsg) + + val handshakerAfterStatusOpt = handshakerAfterHelloOpt.get.applyMessage(newRemoteStatusMsg) + assert(handshakerAfterStatusOpt.isDefined) + + handshakerAfterStatusOpt.get.nextMessage match { + case Left(HandshakeFailure(Disconnect.Reasons.UselessPeer)) => succeed + case other => + fail(s"Invalid handshaker state: $other") + } + + } + it should "fail if a timeout happened during hello exchange" in new TestSetup { val handshakerAfterTimeout = initHandshakerWithoutResolver.processTimeout handshakerAfterTimeout.nextMessage.map(_.messageToSend) shouldBe Left( @@ -447,7 +487,7 @@ class EtcHandshakerSpec extends AnyFlatSpec with Matchers { totalDifficulty = 0, bestHash = genesisBlock.header.hash, genesisHash = genesisBlock.header.hash, - forkId = ForkId(2L, Some(3L)) + forkId = ForkId(0xfc64ec04L, Some(1150000)) ) val remoteStatus: RemoteStatus = RemoteStatus(remoteStatusMsg)