diff --git a/src/main/scala/io/iohk/ethereum/network/Peer.scala b/src/main/scala/io/iohk/ethereum/network/Peer.scala index 397cb08a15..b0f6ccc0cb 100644 --- a/src/main/scala/io/iohk/ethereum/network/Peer.scala +++ b/src/main/scala/io/iohk/ethereum/network/Peer.scala @@ -2,10 +2,14 @@ package io.iohk.ethereum.network import java.net.InetSocketAddress +import akka.NotUsed import akka.actor.ActorRef +import akka.pattern.Patterns.ask +import akka.stream.scaladsl.Source import akka.util.ByteString import io.iohk.ethereum.blockchain.sync.Blacklist.BlacklistId +import io.iohk.ethereum.network.p2p.Message final case class PeerId(value: String) extends BlacklistId @@ -18,6 +22,7 @@ final case class Peer( remoteAddress: InetSocketAddress, ref: ActorRef, incomingConnection: Boolean, + source: Source[Message, NotUsed] = Source.empty, nodeId: Option[ByteString] = None, createTimeMillis: Long = System.currentTimeMillis ) diff --git a/src/main/scala/io/iohk/ethereum/network/PeerActor.scala b/src/main/scala/io/iohk/ethereum/network/PeerActor.scala index f5a142bf77..38696541ae 100644 --- a/src/main/scala/io/iohk/ethereum/network/PeerActor.scala +++ b/src/main/scala/io/iohk/ethereum/network/PeerActor.scala @@ -3,8 +3,11 @@ package io.iohk.ethereum.network import java.net.InetSocketAddress import java.net.URI +import akka.NotUsed import akka.actor.SupervisorStrategy.Escalate import akka.actor._ +import akka.stream.OverflowStrategy +import akka.stream.scaladsl.Source import akka.util.ByteString import org.bouncycastle.util.encoders.Hex @@ -21,6 +24,7 @@ import io.iohk.ethereum.network.handshaker.Handshaker.HandshakeResult import io.iohk.ethereum.network.handshaker.Handshaker.NextMessage import io.iohk.ethereum.network.p2p._ import io.iohk.ethereum.network.p2p.messages.Capability +import io.iohk.ethereum.network.p2p.messages.Codes import io.iohk.ethereum.network.p2p.messages.WireProtocol._ import io.iohk.ethereum.network.rlpx.AuthHandshaker import io.iohk.ethereum.network.rlpx.RLPxConnectionHandler @@ -289,7 +293,17 @@ class PeerActor[R <: HandshakeResult]( class HandshakedPeer(remoteNodeId: ByteString, rlpxConnection: RLPxConnection, handshakeResult: R) { val peerId: PeerId = PeerId(Hex.toHexString(remoteNodeId.toArray)) - val peer: Peer = Peer(peerId, peerAddress, self, incomingConnection, Some(remoteNodeId)) + val source: Source[Message, NotUsed] = PeerEventBusActor + .messageSource( + peerEventBus, + PeerEventBusActor.SubscriptionClassifier + .MessageClassifier( + Set(Codes.BlockBodiesCode, Codes.BlockHeadersCode), + PeerEventBusActor.PeerSelector.WithId(peerId) + ) + ) + .map(_.message) + val peer: Peer = Peer(peerId, peerAddress, self, incomingConnection, source, Some(remoteNodeId)) peerEventBus ! Publish(PeerHandshakeSuccessful(peer, handshakeResult)) /** main behavior of actor that handles peer communication and subscriptions for messages diff --git a/src/main/scala/io/iohk/ethereum/network/PeerEventBusActor.scala b/src/main/scala/io/iohk/ethereum/network/PeerEventBusActor.scala index 05ab900d3b..b8a8414f06 100644 --- a/src/main/scala/io/iohk/ethereum/network/PeerEventBusActor.scala +++ b/src/main/scala/io/iohk/ethereum/network/PeerEventBusActor.scala @@ -1,9 +1,17 @@ package io.iohk.ethereum.network +import akka.NotUsed import akka.actor.Actor import akka.actor.ActorRef import akka.actor.Props +import akka.actor.Terminated import akka.event.ActorEventBus +import akka.stream.OverflowStrategy +import akka.stream.WatchedActorTerminatedException +import akka.stream.scaladsl.Source +import akka.util.Timeout + +import scala.concurrent.Future import io.iohk.ethereum.network.PeerEventBusActor.PeerEvent.MessageFromPeer import io.iohk.ethereum.network.PeerEventBusActor.PeerEvent.PeerDisconnected @@ -11,10 +19,36 @@ import io.iohk.ethereum.network.PeerEventBusActor.PeerEvent.PeerHandshakeSuccess import io.iohk.ethereum.network.PeerEventBusActor.SubscriptionClassifier._ import io.iohk.ethereum.network.handshaker.Handshaker.HandshakeResult import io.iohk.ethereum.network.p2p.Message +import io.iohk.ethereum.network.p2p.messages.Codes object PeerEventBusActor { def props: Props = Props(new PeerEventBusActor) + /** Handle subscription to the peer event bus via Akka Streams. + * + * @param peerEventBus ref to PeerEventBusActor + * @param messageClassifier specify which messages to subscribe to + * @return Source that subscribes to the peer event bus on materialization + * and unsubscribes on cancellation. It will complete when the event bus + * actor terminates. + * + * Note: + * - subscription is asynchronous so it may miss messages when starting. + * - it does not complete when a specified peerId disconnects. + */ + def messageSource(peerEventBus: ActorRef, messageClassifier: MessageClassifier): Source[MessageFromPeer, NotUsed] = + Source + .fromMaterializer { (mat, _) => + val (actorRef, src) = Source + .actorRef[MessageFromPeer](PartialFunction.empty, PartialFunction.empty, 1, OverflowStrategy.fail) + .watch(peerEventBus) + .preMaterialize()(mat) + peerEventBus + .tell(Subscribe(messageClassifier), actorRef) + src + } + .mapMaterializedValue(_ => NotUsed) + sealed trait PeerSelector { def contains(peerId: PeerId): Boolean } @@ -28,7 +62,6 @@ object PeerEventBusActor { case class WithId(peerId: PeerId) extends PeerSelector { override def contains(p: PeerId): Boolean = p == peerId } - } sealed trait SubscriptionClassifier @@ -196,20 +229,28 @@ object PeerEventBusActor { case class Unsubscribe(from: Option[SubscriptionClassifier] = None) case class Publish(ev: PeerEvent) - } class PeerEventBusActor extends Actor { - import PeerEventBusActor._ val peerEventBus: PeerEventBus = new PeerEventBus override def receive: Receive = { - case Subscribe(to) => peerEventBus.subscribe(sender(), to) - case Unsubscribe(Some(from)) => peerEventBus.unsubscribe(sender(), from) - case Unsubscribe(None) => peerEventBus.unsubscribe(sender()) - case Publish(ev: PeerEvent) => peerEventBus.publish(ev) - } + case Subscribe(to) => + peerEventBus.subscribe(sender(), to) + context.watch(sender()) + + case Unsubscribe(Some(from)) => + peerEventBus.unsubscribe(sender(), from) + + case Unsubscribe(None) => + peerEventBus.unsubscribe(sender()) + case Publish(ev: PeerEvent) => + peerEventBus.publish(ev) + + case Terminated(ref) => + peerEventBus.unsubscribe(ref) + } } diff --git a/src/main/scala/io/iohk/ethereum/network/PeerManagerActor.scala b/src/main/scala/io/iohk/ethereum/network/PeerManagerActor.scala index 71a6886fff..28c7cda61f 100644 --- a/src/main/scala/io/iohk/ethereum/network/PeerManagerActor.scala +++ b/src/main/scala/io/iohk/ethereum/network/PeerManagerActor.scala @@ -6,6 +6,7 @@ import java.util.Collections.newSetFromMap import akka.actor.SupervisorStrategy.Stop import akka.actor._ +import akka.stream.scaladsl.Source import akka.util.ByteString import akka.util.Timeout @@ -323,9 +324,7 @@ class PeerManagerActor( PeerId.fromRef(ref), address, ref, - incomingConnection, - nodeId = None, - createTimeMillis = System.currentTimeMillis + incomingConnection ) val newConnectedPeers = connectedPeers.addNewPendingPeer(pendingPeer) diff --git a/src/test/scala/io/iohk/ethereum/blockchain/sync/fast/FastSyncBranchResolverSpec.scala b/src/test/scala/io/iohk/ethereum/blockchain/sync/fast/FastSyncBranchResolverSpec.scala index 37b3cb2f52..2c7ff49417 100644 --- a/src/test/scala/io/iohk/ethereum/blockchain/sync/fast/FastSyncBranchResolverSpec.scala +++ b/src/test/scala/io/iohk/ethereum/blockchain/sync/fast/FastSyncBranchResolverSpec.scala @@ -181,7 +181,8 @@ class FastSyncBranchResolverSpec extends AnyWordSpec with Matchers with MockFact val blocksSavedInPeer: List[Block] = commonBlocks :++ BlockHelpers.generateChain(ourBestBlock + 1 - highestCommonBlock, commonBlocks.last) - val dummyPeer = Peer(PeerId("dummyPeer"), new InetSocketAddress("foo", 1), ActorRef.noSender, false, None, 0) + val dummyPeer = + Peer(PeerId("dummyPeer"), new InetSocketAddress("foo", 1), ActorRef.noSender, false, createTimeMillis = 0) val initialSearchState = SearchState(1, 10, dummyPeer) val ours = blocksSaved.map(b => (b.number, b)).toMap @@ -256,7 +257,8 @@ class FastSyncBranchResolverSpec extends AnyWordSpec with Matchers with MockFact val blocksSaved: List[Block] = BlockHelpers.generateChain(8, BlockHelpers.genesis) val blocksSavedInPeer: List[Block] = BlockHelpers.generateChain(8, BlockHelpers.genesis) - val dummyPeer = Peer(PeerId("dummyPeer"), new InetSocketAddress("foo", 1), ActorRef.noSender, false, None, 0) + val dummyPeer = + Peer(PeerId("dummyPeer"), new InetSocketAddress("foo", 1), ActorRef.noSender, false, createTimeMillis = 0) val initialSearchState = SearchState(1, 8, dummyPeer) val ours = blocksSaved.map(b => (b.number, b)).toMap diff --git a/src/test/scala/io/iohk/ethereum/network/EtcPeerManagerSpec.scala b/src/test/scala/io/iohk/ethereum/network/EtcPeerManagerSpec.scala index ae0714a794..f24cef1018 100644 --- a/src/test/scala/io/iohk/ethereum/network/EtcPeerManagerSpec.scala +++ b/src/test/scala/io/iohk/ethereum/network/EtcPeerManagerSpec.scala @@ -328,20 +328,20 @@ class EtcPeerManagerSpec extends AnyFlatSpec with Matchers { val peer1Probe: TestProbe = TestProbe() val peer1: Peer = - Peer(PeerId("peer1"), new InetSocketAddress("127.0.0.1", 1), peer1Probe.ref, false, Some(fakeNodeId)) + Peer(PeerId("peer1"), new InetSocketAddress("127.0.0.1", 1), peer1Probe.ref, false, nodeId = Some(fakeNodeId)) val peer1Info: PeerInfo = initialPeerInfo.withForkAccepted(false) val peer1InfoETC64: PeerInfo = initialPeerInfoETC64.withForkAccepted(false) val peer2Probe: TestProbe = TestProbe() val peer2: Peer = - Peer(PeerId("peer2"), new InetSocketAddress("127.0.0.1", 2), peer2Probe.ref, false, Some(fakeNodeId)) + Peer(PeerId("peer2"), new InetSocketAddress("127.0.0.1", 2), peer2Probe.ref, false, nodeId = Some(fakeNodeId)) val peer2Info: PeerInfo = initialPeerInfo.withForkAccepted(false) val peer3Probe: TestProbe = TestProbe() val peer3: Peer = - Peer(PeerId("peer3"), new InetSocketAddress("127.0.0.1", 3), peer3Probe.ref, false, Some(fakeNodeId)) + Peer(PeerId("peer3"), new InetSocketAddress("127.0.0.1", 3), peer3Probe.ref, false, nodeId = Some(fakeNodeId)) val freshPeerProbe: TestProbe = TestProbe() val freshPeer: Peer = - Peer(PeerId(""), new InetSocketAddress("127.0.0.1", 4), freshPeerProbe.ref, false, Some(fakeNodeId)) + Peer(PeerId(""), new InetSocketAddress("127.0.0.1", 4), freshPeerProbe.ref, false, nodeId = Some(fakeNodeId)) val freshPeerInfo: PeerInfo = initialPeerInfo.withForkAccepted(false) val peerManager: TestProbe = TestProbe() diff --git a/src/test/scala/io/iohk/ethereum/network/PeerEventBusActorSpec.scala b/src/test/scala/io/iohk/ethereum/network/PeerEventBusActorSpec.scala index 5ab613aeec..b75beda5dc 100644 --- a/src/test/scala/io/iohk/ethereum/network/PeerEventBusActorSpec.scala +++ b/src/test/scala/io/iohk/ethereum/network/PeerEventBusActorSpec.scala @@ -4,13 +4,22 @@ import java.net.InetSocketAddress import akka.actor.ActorRef import akka.actor.ActorSystem +import akka.actor.PoisonPill +import akka.stream.WatchedActorTerminatedException +import akka.stream.scaladsl.Flow +import akka.stream.scaladsl.Keep +import akka.stream.scaladsl.Sink +import akka.stream.scaladsl.Source +import akka.testkit.TestActor import akka.testkit.TestProbe import akka.util.ByteString +import org.scalatest.concurrent.ScalaFutures import org.scalatest.flatspec.AnyFlatSpec import org.scalatest.matchers.should.Matchers import io.iohk.ethereum.Fixtures +import io.iohk.ethereum.NormalPatience import io.iohk.ethereum.domain.ChainWeight import io.iohk.ethereum.network.EtcPeerManagerActor.PeerInfo import io.iohk.ethereum.network.EtcPeerManagerActor.RemoteStatus @@ -19,11 +28,12 @@ import io.iohk.ethereum.network.PeerEventBusActor.PeerEvent.PeerDisconnected import io.iohk.ethereum.network.PeerEventBusActor.PeerEvent.PeerHandshakeSuccessful import io.iohk.ethereum.network.PeerEventBusActor.PeerSelector import io.iohk.ethereum.network.PeerEventBusActor.SubscriptionClassifier._ +import io.iohk.ethereum.network.p2p.Message import io.iohk.ethereum.network.p2p.messages.Capability import io.iohk.ethereum.network.p2p.messages.WireProtocol.Ping import io.iohk.ethereum.network.p2p.messages.WireProtocol.Pong -class PeerEventBusActorSpec extends AnyFlatSpec with Matchers { +class PeerEventBusActorSpec extends AnyFlatSpec with Matchers with ScalaFutures with NormalPatience { "PeerEventBusActor" should "relay messages received to subscribers" in new TestSetup { @@ -32,6 +42,7 @@ class PeerEventBusActorSpec extends AnyFlatSpec with Matchers { val classifier1 = MessageClassifier(Set(Ping.code), PeerSelector.WithId(PeerId("1"))) val classifier2 = MessageClassifier(Set(Ping.code), PeerSelector.AllPeers) peerEventBusActor.tell(PeerEventBusActor.Subscribe(classifier1), probe1.ref) + peerEventBusActor.tell(PeerEventBusActor.Subscribe(classifier2), probe2.ref) val msgFromPeer = MessageFromPeer(Ping(), PeerId("1")) @@ -46,6 +57,47 @@ class PeerEventBusActorSpec extends AnyFlatSpec with Matchers { peerEventBusActor ! PeerEventBusActor.Publish(msgFromPeer2) probe1.expectNoMessage() probe2.expectMsg(msgFromPeer2) + + } + + it should "relay messages via streams" in new TestSetup { + val classifier1 = MessageClassifier(Set(Ping.code), PeerSelector.WithId(PeerId("1"))) + val classifier2 = MessageClassifier(Set(Ping.code), PeerSelector.AllPeers) + + val peerEventBusProbe = TestProbe()(system) + peerEventBusProbe.setAutoPilot { (sender: ActorRef, msg: Any) => + peerEventBusActor.tell(msg, sender) + TestActor.KeepRunning + } + + val seqOnTermination = Flow[MessageFromPeer] + .recoverWithRetries(1, { case _: WatchedActorTerminatedException => Source.empty }) + .toMat(Sink.seq)(Keep.right) + + val stream1 = PeerEventBusActor.messageSource(peerEventBusProbe.ref, classifier1).runWith(seqOnTermination) + val stream2 = PeerEventBusActor.messageSource(peerEventBusProbe.ref, classifier2).runWith(seqOnTermination) + + // wait for subscriptions to be done + peerEventBusProbe.expectMsgType[PeerEventBusActor.Subscribe] + peerEventBusProbe.expectMsgType[PeerEventBusActor.Subscribe] + + val syncProbe = TestProbe()(system) + peerEventBusActor.tell(PeerEventBusActor.Subscribe(classifier2), syncProbe.ref) + + val msgFromPeer = MessageFromPeer(Ping(), PeerId("1")) + peerEventBusActor ! PeerEventBusActor.Publish(msgFromPeer) + + val msgFromPeer2 = MessageFromPeer(Ping(), PeerId("99")) + peerEventBusActor ! PeerEventBusActor.Publish(msgFromPeer2) + + // wait for publications to be done + syncProbe.expectMsg(msgFromPeer) + syncProbe.expectMsg(msgFromPeer2) + + peerEventBusProbe.ref ! PoisonPill + + whenReady(stream1)(_ shouldEqual Seq(msgFromPeer)) + whenReady(stream2)(_ shouldEqual Seq(msgFromPeer, msgFromPeer2)) } it should "only relay matching message codes" in new TestSetup { @@ -105,7 +157,13 @@ class PeerEventBusActorSpec extends AnyFlatSpec with Matchers { peerEventBusActor.tell(PeerEventBusActor.Subscribe(PeerHandshaked), probe2.ref) val peerHandshaked = - new Peer(PeerId("peer1"), new InetSocketAddress("127.0.0.1", 0), TestProbe().ref, false, Some(ByteString())) + new Peer( + PeerId("peer1"), + new InetSocketAddress("127.0.0.1", 0), + TestProbe().ref, + false, + nodeId = Some(ByteString()) + ) val msgPeerHandshaked = PeerHandshakeSuccessful(peerHandshaked, initialPeerInfo) peerEventBusActor ! PeerEventBusActor.Publish(msgPeerHandshaked) diff --git a/src/test/scala/io/iohk/ethereum/network/PeerManagerSpec.scala b/src/test/scala/io/iohk/ethereum/network/PeerManagerSpec.scala index 7c264b851a..b2a44091a1 100644 --- a/src/test/scala/io/iohk/ethereum/network/PeerManagerSpec.scala +++ b/src/test/scala/io/iohk/ethereum/network/PeerManagerSpec.scala @@ -195,7 +195,8 @@ class PeerManagerSpec // It should have created the next peer for the first incoming connection (probably using a synchronous test scheduler). val probe2: TestProbe = createdPeers(2).probe - val peer = Peer(PeerId("peer"), incomingPeerAddress1, probe2.ref, incomingConnection = true, Some(incomingNodeId1)) + val peer = + Peer(PeerId("peer"), incomingPeerAddress1, probe2.ref, incomingConnection = true, nodeId = Some(incomingNodeId1)) probe2.expectMsg(PeerActor.HandleConnection(incomingConnection1.ref, incomingPeerAddress1)) probe2.reply(PeerEvent.PeerHandshakeSuccessful(peer, initialPeerInfo)) @@ -213,7 +214,13 @@ class PeerManagerSpec val probe3: TestProbe = createdPeers(3).probe val secondPeer = - Peer(PeerId("secondPeer"), incomingPeerAddress2, probe3.ref, incomingConnection = true, Some(incomingNodeId2)) + Peer( + PeerId("secondPeer"), + incomingPeerAddress2, + probe3.ref, + incomingConnection = true, + nodeId = Some(incomingNodeId2) + ) probe3.expectMsg(PeerActor.HandleConnection(incomingConnection2.ref, incomingPeerAddress2)) probe3.reply(PeerEvent.PeerHandshakeSuccessful(secondPeer, initialPeerInfo)) @@ -287,7 +294,7 @@ class PeerManagerSpec peerAsIncomingAddress, peerAsIncomingProbe.ref, incomingConnection = true, - Some(nodeId) + nodeId = Some(nodeId) ) peerAsIncomingProbe.expectMsg( @@ -322,7 +329,7 @@ class PeerManagerSpec peerAsIncomingAddress, peerAsIncomingProbe.ref, incomingConnection = true, - Some(nodeId) + nodeId = Some(nodeId) ) peerAsIncomingProbe.expectMsg(