diff --git a/scalanet/discovery/it/src/io/iohk/scalanet/discovery/ethereum/v4/DiscoveryKademliaIntegrationSpec.scala b/scalanet/discovery/it/src/io/iohk/scalanet/discovery/ethereum/v4/DiscoveryKademliaIntegrationSpec.scala
index 99e9112a..14a39e2c 100644
--- a/scalanet/discovery/it/src/io/iohk/scalanet/discovery/ethereum/v4/DiscoveryKademliaIntegrationSpec.scala
+++ b/scalanet/discovery/it/src/io/iohk/scalanet/discovery/ethereum/v4/DiscoveryKademliaIntegrationSpec.scala
@@ -62,7 +62,8 @@ class DiscoveryKademliaIntegrationSpec extends KademliaIntegrationSpec("Discover
         kademliaAlpha = testConfig.alpha,
         kademliaBucketSize = testConfig.k,
         discoveryPeriod = testConfig.refreshRate,
-        knownPeers = initialNodes
+        knownPeers = initialNodes,
+        subnetLimitPrefixLength = 0
       )
       network <- Resource.liftF {
        DiscoveryNetwork[InetMultiAddress](
diff --git a/scalanet/discovery/src/io/iohk/scalanet/discovery/ethereum/v4/DiscoveryConfig.scala b/scalanet/discovery/src/io/iohk/scalanet/discovery/ethereum/v4/DiscoveryConfig.scala
index 0c0e4079..0180e4df 100644
--- a/scalanet/discovery/src/io/iohk/scalanet/discovery/ethereum/v4/DiscoveryConfig.scala
+++ b/scalanet/discovery/src/io/iohk/scalanet/discovery/ethereum/v4/DiscoveryConfig.scala
@@ -23,7 +23,13 @@ case class DiscoveryConfig(
     // How often to look for new peers.
     discoveryPeriod: FiniteDuration,
     // Bootstrap nodes.
-    knownPeers: Set[Node]
+    knownPeers: Set[Node],
+    // The prefix length that defines a subnet for the following limits, e.g. 24 for a /24; 0 means no limit.
+    subnetLimitPrefixLength: Int,
+    // Limit the number of IPs from the same subnet in any given bucket; 0 means no limit.
+    subnetLimitForBucket: Int,
+    // Limit the number of IPs from the same subnet in the whole K-table; 0 means no limit.
+    subnetLimitForTable: Int
 )

 object DiscoveryConfig {
@@ -36,6 +42,9 @@ object DiscoveryConfig {
       kademliaAlpha = 3,
       bondExpiration = 12.hours,
       discoveryPeriod = 15.minutes,
-      knownPeers = Set.empty
+      knownPeers = Set.empty,
+      subnetLimitPrefixLength = 24,
+      subnetLimitForBucket = 2,
+      subnetLimitForTable = 10
     )
 }
diff --git a/scalanet/discovery/src/io/iohk/scalanet/discovery/ethereum/v4/DiscoveryService.scala b/scalanet/discovery/src/io/iohk/scalanet/discovery/ethereum/v4/DiscoveryService.scala
index 5ab85254..889051e0 100644
--- a/scalanet/discovery/src/io/iohk/scalanet/discovery/ethereum/v4/DiscoveryService.scala
+++ b/scalanet/discovery/src/io/iohk/scalanet/discovery/ethereum/v4/DiscoveryService.scala
@@ -6,7 +6,7 @@ import cats.implicits._
 import io.iohk.scalanet.discovery.crypto.{PrivateKey, SigAlg}
 import io.iohk.scalanet.discovery.ethereum.{Node, EthereumNodeRecord}
 import io.iohk.scalanet.discovery.hash.Hash
-import io.iohk.scalanet.kademlia.{KBuckets, XorOrdering}
+import io.iohk.scalanet.kademlia.XorOrdering
 import io.iohk.scalanet.peergroup.Addressable
 import java.net.InetAddress
 import monix.eval.Task
@@ -53,6 +53,7 @@ trait DiscoveryService {
 object DiscoveryService {
   import DiscoveryRPC.{Call, Proc}
   import DiscoveryNetwork.Peer
+  import KBucketsWithSubnetLimits.SubnetLimits

   type ENRSeq = Long
   type Timestamp = Long
@@ -90,7 +91,7 @@ object DiscoveryService {
       // Use the current time to set the ENR sequence to something fresh.
       now <- clock.monotonic(MILLISECONDS)
       enr <- Task(EthereumNodeRecord.fromNode(node, privateKey, seq = now).require)
-      stateRef <- Ref[Task].of(State[A](node, enr))
+      stateRef <- Ref[Task].of(State[A](node, enr, SubnetLimits.fromConfig(config)))
       service <- Task(new ServiceImpl[A](privateKey, config, network, stateRef, toAddress))
       // Start handling requests, we need them during enrolling so the peers can ping and bond with us.
       cancelToken <- network.startHandling(service)
@@ -144,7 +145,7 @@ object DiscoveryService {
       node: Node,
       enr: EthereumNodeRecord,
       // Kademlia buckets with hashes of the nodes' IDs in them.
-      kBuckets: KBuckets[Hash],
+      kBuckets: KBucketsWithSubnetLimits[A],
       kademliaIdToNodeId: Map[Hash, Node.Id],
       nodeMap: Map[Node.Id, Node],
       enrMap: Map[Node.Id, EthereumNodeRecord],
@@ -178,10 +179,10 @@ object DiscoveryService {
         kBuckets =
           if (isSelf(peer))
             kBuckets
-          else if (kBuckets.getBucket(peer.kademliaId)._2.contains(peer.kademliaId))
-            kBuckets.touch(peer.kademliaId)
+          else if (kBuckets.contains(peer))
+            kBuckets.touch(peer)
           else if (addToBucket)
-            kBuckets.add(peer.kademliaId)
+            kBuckets.add(peer)
           else
             kBuckets,
         kademliaIdToNodeId = kademliaIdToNodeId.updated(peer.kademliaId, peer.id)
@@ -190,8 +191,8 @@ object DiscoveryService {

     /** Update the timestamp of the peer in the K-table, if it's still part of it. */
     def withTouch(peer: Peer[A]): State[A] =
-      if (kBuckets.contains(peer.kademliaId))
-        copy(kBuckets = kBuckets.touch(peer.kademliaId))
+      if (kBuckets.contains(peer))
+        copy(kBuckets = kBuckets.touch(peer))
       else
         // Not adding because `kademliaIdToNodeId` and `nodeMap` may no longer have this peer.
         this
@@ -217,7 +218,7 @@ object DiscoveryService {
           copy(
             nodeMap = nodeMap - peer.id,
             enrMap = enrMap - peer.id,
-            kBuckets = kBuckets.remove(peer.kademliaId),
+            kBuckets = kBuckets.remove(peer),
             kademliaIdToNodeId = kademliaIdToNodeId - peer.kademliaId
           )
         case _ => this
@@ -228,32 +229,36 @@ object DiscoveryService {
       )
     }

-    def removePeer(peerId: Node.Id): State[A] =
+    def removePeer(peerId: Node.Id, toAddress: Node.Address => A): State[A] = {
+      // Find any Peer records that correspond to this ID.
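+      // (These maps are keyed by Peer, which pairs the ID with an address, so one ID may match several entries.)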
+      val peers: Set[Peer[A]] = (
+        nodeMap.get(peerId).map(node => Peer(node.id, toAddress(node.address))).toSeq ++
+          lastPongTimestampMap.keys.filter(_.id == peerId).toSeq ++
+          bondingResultsMap.keys.filter(_.id == peerId).toSeq
+      ).toSet
+
       copy(
         nodeMap = nodeMap - peerId,
         enrMap = enrMap - peerId,
-        lastPongTimestampMap = lastPongTimestampMap.filterNot {
-          case (peer, _) => peer.id == peerId
-        },
-        bondingResultsMap = bondingResultsMap.filterNot {
-          case (peer, _) => peer.id == peerId
-        },
-        kBuckets = kBuckets.remove(Node.kademliaId(peerId)),
+        lastPongTimestampMap = lastPongTimestampMap -- peers,
+        bondingResultsMap = bondingResultsMap -- peers,
+        kBuckets = peers.foldLeft(kBuckets)(_ remove _),
         kademliaIdToNodeId = kademliaIdToNodeId - Node.kademliaId(peerId)
       )
+    }

     def setEnrolled: State[A] =
       copy(hasEnrolled = true)
   }

   protected[v4] object State {
-    def apply[A](
+    def apply[A: Addressable](
         node: Node,
         enr: EthereumNodeRecord,
-        clock: java.time.Clock = java.time.Clock.systemUTC()
+        subnetLimits: SubnetLimits
     ): State[A] = State[A](
       node = node,
       enr = enr,
-      kBuckets = new KBuckets[Hash](node.kademliaId, clock),
+      kBuckets = KBucketsWithSubnetLimits[A](node, subnetLimits),
       kademliaIdToNodeId = Map(node.kademliaId -> node.id),
       nodeMap = Map(node.id -> node),
       enrMap = Map(node.id -> enr),
@@ -306,7 +311,7 @@ object DiscoveryService {

     override def removeNode(nodeId: Node.Id): Task[Unit] =
       stateRef.update { state =>
-        if (state.node.id == nodeId) state else state.removePeer(nodeId)
+        if (state.node.id == nodeId) state else state.removePeer(nodeId, toAddress)
       }

     /** Update the node and ENR of the local peer with the new address and ping peers with the new ENR seq. */
@@ -673,7 +678,7 @@ object DiscoveryService {
         if (state.isSelf(peer))
           state -> None
         else {
-          val (_, bucket) = state.kBuckets.getBucket(peer.kademliaId)
+          val (_, bucket) = state.kBuckets.getBucket(peer)
           val (addToBucket, maybeEvict) =
             if (bucket.contains(peer.kademliaId) || bucket.size < config.kademliaBucketSize) {
               // We can just update the records, the bucket either has room or won't need to grow.
diff --git a/scalanet/discovery/src/io/iohk/scalanet/discovery/ethereum/v4/KBucketsWithSubnetLimits.scala b/scalanet/discovery/src/io/iohk/scalanet/discovery/ethereum/v4/KBucketsWithSubnetLimits.scala
new file mode 100644
index 00000000..4e80aa02
--- /dev/null
+++ b/scalanet/discovery/src/io/iohk/scalanet/discovery/ethereum/v4/KBucketsWithSubnetLimits.scala
@@ -0,0 +1,159 @@
+package io.iohk.scalanet.discovery.ethereum.v4
+
+import cats._
+import cats.implicits._
+import io.iohk.scalanet.discovery.hash.Hash
+import io.iohk.scalanet.discovery.ethereum.Node
+import io.iohk.scalanet.kademlia.{KBuckets, TimeSet}
+import io.iohk.scalanet.peergroup.Addressable
+import io.iohk.scalanet.peergroup.InetAddressOps._
+import java.net.InetAddress
+
+case class KBucketsWithSubnetLimits[A: Addressable](
+    table: KBuckets[Hash],
+    limits: KBucketsWithSubnetLimits.SubnetLimits,
+    tableLevelCounts: KBucketsWithSubnetLimits.TableLevelCounts,
+    bucketLevelCounts: KBucketsWithSubnetLimits.BucketLevelCounts
+) {
+  import DiscoveryNetwork.Peer
+  import KBucketsWithSubnetLimits._
+
+  def contains(peer: Peer[A]): Boolean =
+    table.contains(peer.kademliaId)
+
+  def touch(peer: Peer[A]): KBucketsWithSubnetLimits[A] =
+    // Note that `KBuckets.touch` also adds, so if the record
+    // isn't in the table already then use `add` to maintain counts.
+    if (contains(peer)) copy(table = table.touch(peer.kademliaId)) else add(peer)
+
+  /** Add the peer to the underlying K-table unless doing so would violate some limit. */
+  def add(peer: Peer[A]): KBucketsWithSubnetLimits[A] =
+    if (contains(peer)) this
+    else {
+      val ip = subnet(peer)
+      val idx = getBucket(peer)._1
+
+      // Upsert the counts of the index and/or IP in the maps, so that we can check the limits on them.
+      val tlc = incrementForTable(ip)
+      val blc = incrementForBucket(idx, ip)
+
+      val isOverAnyLimit =
+        limits.isOverLimitForTable(tlc(ip)) ||
+          limits.isOverLimitForBucket(blc(idx)(ip))
+
+      if (isOverAnyLimit) this
+      else {
+        copy(
+          table = table.add(peer.kademliaId),
+          tableLevelCounts = tlc,
+          bucketLevelCounts = blc
+        )
+      }
+    }
+
+  def remove(peer: Peer[A]): KBucketsWithSubnetLimits[A] =
+    if (!contains(peer)) this
+    else {
+      val ip = subnet(peer)
+      val idx = getBucket(peer)._1
+
+      val tlc = decrementForTable(ip)
+      val blc = decrementForBucket(idx, ip)
+
+      copy(table = table.remove(peer.kademliaId), tableLevelCounts = tlc, bucketLevelCounts = blc)
+    }
+
+  def closestNodes(targetKademliaId: Hash, n: Int): List[Hash] =
+    table.closestNodes(targetKademliaId, n)
+
+  def getBucket(peer: Peer[A]): (Int, TimeSet[Hash]) =
+    table.getBucket(peer.kademliaId)
+
+  private def subnet(peer: Peer[A]): InetAddress =
+    Addressable[A].getAddress(peer.address).getAddress.truncate(limits.prefixLength)
+
+  /** Increment the table level count for the subnet of an IP. */
+  private def incrementForTable(ip: InetAddress): TableLevelCounts =
+    tableLevelCounts |+| Map(ip -> 1)
+
+  /** Increment the bucket level count for the subnet of an IP. */
+  private def incrementForBucket(idx: Int, ip: InetAddress): BucketLevelCounts =
+    bucketLevelCounts |+| Map(idx -> Map(ip -> 1))
+
+  /** Decrement the table level count for the subnet of an IP and remove the entry if it drops to zero. */
+  private def decrementForTable(ip: InetAddress): TableLevelCounts =
+    tableLevelCounts |+| Map(ip -> -1) match {
+      case counts if counts(ip) <= 0 => counts - ip
+      case counts => counts
+    }
+
+  /** Decrement the bucket level count for the subnet of an IP; remove the subnet entry if its count
+    * drops to zero, and the whole bucket entry if it has no subnets left.
+    */
+  private def decrementForBucket(idx: Int, ip: InetAddress): BucketLevelCounts =
+    bucketLevelCounts |+| Map(idx -> Map(ip -> -1)) match {
+      case counts if counts(idx)(ip) <= 0 && counts(idx).size > 1 =>
+        // The subnet count in the bucket is zero, but there are other subnets in the bucket,
+        // so keep the bucket level count and just remove the subnet from it.
+        counts.updated(idx, counts(idx) - ip)
+      case counts if counts(idx)(ip) <= 0 =>
+        // The subnet count is zero, and it's the only subnet in the bucket, so remove the bucket.
+        counts - idx
+      case counts =>
+        counts
+    }
+}
+
+object KBucketsWithSubnetLimits {
+  type SubnetCounts = Map[InetAddress, Int]
+  type TableLevelCounts = SubnetCounts
+  type BucketLevelCounts = Map[Int, SubnetCounts]
+
+  case class SubnetLimits(
+      // Number of leftmost bits of the IP address that counts as a subnet, serving as its ID.
+      prefixLength: Int,
+      // Limit of nodes from the same subnet within any given bucket in the K-table.
+      forBucket: Int,
+      // Limit of nodes from the same subnet across all buckets in the K-table.
+      forTable: Int
+  ) {

+    /** All limits can be disabled by setting the subnet prefix length to 0.
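+      * For example, the defaults (prefixLength = 24, forBucket = 2, forTable = 10) allow at most
+      * two nodes from any one /24 subnet in a bucket, and at most ten in the whole table.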
+      */
+    def isEnabled: Boolean = prefixLength > 0
+
+    def isEnabledForBucket: Boolean =
+      isEnabled && forBucket > 0
+
+    def isEnabledForTable: Boolean =
+      isEnabled && forTable > 0
+
+    def isOverLimitForBucket(count: Int): Boolean =
+      isEnabledForBucket && count > forBucket
+
+    def isOverLimitForTable(count: Int): Boolean =
+      isEnabledForTable && count > forTable
+  }
+
+  object SubnetLimits {
+    val Unlimited = SubnetLimits(0, 0, 0)
+
+    def fromConfig(config: DiscoveryConfig): SubnetLimits =
+      SubnetLimits(
+        prefixLength = config.subnetLimitPrefixLength,
+        forBucket = config.subnetLimitForBucket,
+        forTable = config.subnetLimitForTable
+      )
+  }
+
+  def apply[A: Addressable](
+      node: Node,
+      limits: SubnetLimits
+  ): KBucketsWithSubnetLimits[A] = {
+    KBucketsWithSubnetLimits[A](
+      new KBuckets[Hash](node.kademliaId, clock = java.time.Clock.systemUTC()),
+      limits,
+      tableLevelCounts = Map.empty[InetAddress, Int],
+      bucketLevelCounts = Map.empty[Int, Map[InetAddress, Int]]
+    )
+  }
+}
diff --git a/scalanet/discovery/ut/src/io/iohk/scalanet/discovery/ethereum/v4/DiscoveryServiceSpec.scala b/scalanet/discovery/ut/src/io/iohk/scalanet/discovery/ethereum/v4/DiscoveryServiceSpec.scala
index 9fc7675d..674ae81e 100644
--- a/scalanet/discovery/ut/src/io/iohk/scalanet/discovery/ethereum/v4/DiscoveryServiceSpec.scala
+++ b/scalanet/discovery/ut/src/io/iohk/scalanet/discovery/ethereum/v4/DiscoveryServiceSpec.scala
@@ -17,6 +17,7 @@ import org.scalatest._
 import scala.concurrent.duration._
 import scala.util.Random
 import java.net.InetAddress
+import io.iohk.scalanet.discovery.ethereum.v4.KBucketsWithSubnetLimits.SubnetLimits

 class DiscoveryServiceSpec extends AsyncFlatSpec with Matchers {
   import DiscoveryService.{State, BondingResults}
@@ -411,7 +412,7 @@ class DiscoveryServiceSpec extends AsyncFlatSpec with Matchers {
           state.fetchEnrMap should not contain key(remotePeer)
           state.nodeMap should contain key (remotePeer.id)
           state.enrMap(remotePeer.id) shouldBe remoteENR
-          state.kBuckets.contains(remotePeer.kademliaId) shouldBe true
+          state.kBuckets.contains(remotePeer) shouldBe true
         }
       }
   }
@@ -488,13 +489,13 @@ class DiscoveryServiceSpec extends AsyncFlatSpec with Matchers {
         // If the existing peer didn't respond, forget them completely.
         state.nodeMap.contains(peer1._1.id) shouldBe responds
         state.enrMap.contains(peer1._1.id) shouldBe responds
-        state.kBuckets.contains(peer1._1.kademliaId) shouldBe responds
+        state.kBuckets.contains(peer1._1) shouldBe responds
         // Add the new ENR of the peer regardless of the existing.
         state.nodeMap.contains(peer2._1.id) shouldBe true
         state.enrMap.contains(peer2._1.id) shouldBe true
         // Only use them for routing if the existing got evicted.
-        state.kBuckets.contains(peer2._1.kademliaId) shouldBe !responds
+        state.kBuckets.contains(peer2._1) shouldBe !responds
       }
   }
@@ -936,7 +937,7 @@ class DiscoveryServiceSpec extends AsyncFlatSpec with Matchers {
        state.enrMap should not contain key(remotePublicKey)
        state.nodeMap should not contain key(remotePublicKey)
        state.lastPongTimestampMap should not contain key(remotePeer)
-       state.kBuckets.contains(remoteNode.kademliaId) shouldBe false
+       state.kBuckets.contains(remotePeer) shouldBe false
      }
    }
  }
@@ -949,7 +950,7 @@ class DiscoveryServiceSpec extends AsyncFlatSpec with Matchers {
      } yield {
        state.enrMap should contain key (localPublicKey)
        state.nodeMap should contain key (localPublicKey)
-       state.kBuckets.contains(localNode.kademliaId) shouldBe true
+       state.kBuckets.contains(localPeer) shouldBe true
      }
    }
  }
@@ -1015,14 +1016,12 @@ class DiscoveryServiceSpec extends AsyncFlatSpec with Matchers {
       _ <- Task.sleep(1.milli) // If we're too quick then TimeSet will assign the same timestamp.
       state3 = state2.withTouch(remotePeer)
     } yield {
-      val peerId = remotePeer.kademliaId
+      state0.kBuckets.contains(remotePeer) shouldBe false
+      state1.kBuckets.contains(remotePeer) shouldBe false
+      state2.kBuckets.contains(remotePeer) shouldBe true

-      state0.kBuckets.contains(peerId) shouldBe false
-      state1.kBuckets.contains(peerId) shouldBe false
-      state2.kBuckets.contains(peerId) shouldBe true
-
-      val (_, bucket2) = state2.kBuckets.getBucket(peerId)
-      val (_, bucket3) = state3.kBuckets.getBucket(peerId)
+      val (_, bucket2) = state2.kBuckets.getBucket(remotePeer)
+      val (_, bucket3) = state3.kBuckets.getBucket(remotePeer)
       bucket2.timestamps should not equal bucket3.timestamps
     }
   }
@@ -1048,7 +1047,8 @@ object DiscoveryServiceSpec {
   val unimplementedRPC = StubDiscoveryRPC()

   val defaultConfig = DiscoveryConfig.default.copy(
-    requestTimeout = 50.millis
+    requestTimeout = 100.millis,
+    subnetLimitPrefixLength = 0
   )

   trait Fixture {
@@ -1070,12 +1070,10 @@ object DiscoveryServiceSpec {
     lazy val remoteENR = EthereumNodeRecord.fromNode(remoteNode, remotePrivateKey, seq = 1).require

     lazy val stateRef = Ref.unsafe[Task, DiscoveryService.State[InetSocketAddress]](
-      DiscoveryService.State[InetSocketAddress](localNode, localENR)
+      DiscoveryService.State[InetSocketAddress](localNode, localENR, SubnetLimits.fromConfig(config))
     )

-    lazy val config: DiscoveryConfig = defaultConfig.copy(
-      requestTimeout = 100.millis
-    )
+    lazy val config: DiscoveryConfig = defaultConfig

     lazy val rpc = unimplementedRPC
diff --git a/scalanet/discovery/ut/src/io/iohk/scalanet/discovery/ethereum/v4/KBucketsWithSubnetLimitsSpec.scala b/scalanet/discovery/ut/src/io/iohk/scalanet/discovery/ethereum/v4/KBucketsWithSubnetLimitsSpec.scala
new file mode 100644
index 00000000..8f6a63fb
--- /dev/null
+++ b/scalanet/discovery/ut/src/io/iohk/scalanet/discovery/ethereum/v4/KBucketsWithSubnetLimitsSpec.scala
@@ -0,0 +1,139 @@
+package io.iohk.scalanet.discovery.ethereum.v4
+
+import org.scalatest._
+import java.net.InetSocketAddress
+import io.iohk.scalanet.discovery.ethereum.v4.KBucketsWithSubnetLimits.SubnetLimits
+import io.iohk.scalanet.discovery.ethereum.Node
+import io.iohk.scalanet.discovery.hash.Keccak256
+import java.net.InetAddress
+import io.iohk.scalanet.discovery.ethereum.v4.DiscoveryNetwork.Peer
+import scodec.bits.BitVector
+import io.iohk.scalanet.discovery.crypto.PublicKey
+
+class KBucketsWithSubnetLimitsSpec extends FlatSpec with Matchers with Inspectors {
+
+  // For the tests I only care about the IP addresses; a 1-to-1 mapping is convenient.
+  def fakeNodeId(address: InetAddress): Node.Id =
+    PublicKey(Keccak256(BitVector(address.getAddress)))
+
+  def makeNode(address: InetSocketAddress) =
+    Node(fakeNodeId(address.getAddress), Node.Address(address.getAddress, address.getPort, address.getPort))
+
+  def makePeer(address: InetAddress, port: Int = 30303) =
+    Peer[InetSocketAddress](id = fakeNodeId(address), address = new InetSocketAddress(address, port))
+
+  def makeIp(name: String) = InetAddress.getByName(name)
+
+  val localNode = makeNode(new InetSocketAddress("127.0.0.1", 30303))
+  val defaultLimits = SubnetLimits(prefixLength = 24, forBucket = 2, forTable = 10)
+
+  trait Fixture {
+    lazy val limits = defaultLimits
+    lazy val ips: Vector[String] = Vector.empty
+    lazy val peers = ips.map(ip => makePeer(makeIp(ip)))
+    lazy val kBuckets = peers.foldLeft(KBucketsWithSubnetLimits(localNode, limits = limits))(_.add(_))
+  }
+
+  behavior of "KBucketsWithSubnetLimits"
+
+  it should "increment the count of the subnet after add" in new Fixture {
+    override lazy val ips = Vector("5.67.8.9", "5.67.8.10", "5.67.1.2")
+    val subnet = makeIp("5.67.8.0")
+    val idx = kBuckets.getBucket(peers.head)._1
+    kBuckets.tableLevelCounts(subnet) shouldBe 2
+    kBuckets.tableLevelCounts.values should contain theSameElementsAs List(2, 1)
+    kBuckets.bucketLevelCounts(idx)(subnet) shouldBe >=(1)
+  }
+
+  it should "not increment the count if the peer is already in the table" in new Fixture {
+    override lazy val ips = Vector("5.67.8.9", "5.67.8.9", "5.67.8.9")
+    val subnet = makeIp("5.67.8.0")
+    val idx = kBuckets.getBucket(peers.head)._1
+    kBuckets.tableLevelCounts(subnet) shouldBe 1
+    kBuckets.bucketLevelCounts(idx)(subnet) shouldBe 1
+  }
+
+  it should "decrement the count after removal" in new Fixture {
+    override lazy val ips = Vector("5.67.8.9", "5.67.8.10")
+
+    val removed0 = kBuckets.remove(peers(0))
+    removed0.tableLevelCounts.values.toList shouldBe List(1)
+    removed0.bucketLevelCounts.values.toList shouldBe List(Map(makeIp("5.67.8.0") -> 1))
+
+    val removed1 = removed0.remove(peers(1))
+    removed1.tableLevelCounts shouldBe empty
+    removed1.bucketLevelCounts shouldBe empty
+  }
+
+  it should "not decrement if the peer is not in the table" in new Fixture {
+    override lazy val ips = Vector("1.2.3.4")
+    val removed = kBuckets.remove(makePeer(makeIp("1.2.3.5")))
+    removed.tableLevelCounts should not be empty
+    removed.bucketLevelCounts should not be empty
+  }
+
+  it should "not add IP if it violates the limits" in new Fixture {
+    override lazy val ips = Vector.range(0, defaultLimits.forTable + 1).map(i => s"192.168.1.$i")
+
+    forAll(peers.take(defaultLimits.forBucket)) { peer =>
+      kBuckets.contains(peer) shouldBe true
+    }
+
+    forAtLeast(1, peers) { peer =>
+      kBuckets.contains(peer) shouldBe false
+    }
+
+    forAll(peers) { peer =>
+      val (_, bucket) = kBuckets.getBucket(peer)
+      bucket.size shouldBe <=(defaultLimits.forBucket)
+    }
+  }
+
+  it should "treat limits separately per subnet" in new Fixture {
+    override lazy val ips = Vector.range(0, 256).map { i =>
+      s"192.168.1.$i"
+    } :+ "192.168.2.1"
+
+    kBuckets.contains(peers.last) shouldBe true
+  }
+
+  it should "add peers after removing previous ones" in new Fixture {
+    override lazy val ips = Vector.range(0, 255).map(i => s"192.168.1.$i")
+
+    kBuckets.tableLevelCounts.values.toList shouldBe List(defaultLimits.forTable)
+
+    val peer = makePeer(makeIp("192.168.1.255"))
+    kBuckets.add(peer).contains(peer) shouldBe false
+    kBuckets.remove(peer).add(peer).contains(peer) shouldBe false
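+    // Removing a never-added peer frees no quota for the subnet, but removing a counted peer does: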
+    kBuckets.remove(peers.head).add(peer).contains(peer) shouldBe true
+  }
+
+  it should "not use limits if the prefix is 0" in new Fixture {
+    override lazy val limits = defaultLimits.copy(prefixLength = 0)
+    override lazy val ips = Vector.range(0, 256).map(i => s"192.168.1.$i")
+
+    kBuckets.tableLevelCounts.values.toList shouldBe List(256)
+  }
+
+  it should "not use limits if the table level limit is 0, but still apply the bucket limit" in new Fixture {
+    override lazy val limits = defaultLimits.copy(forTable = 0)
+    override lazy val ips = Vector.range(0, 256).map(i => s"192.168.1.$i")
+
+    kBuckets.tableLevelCounts.values.toList.head shouldBe >(defaultLimits.forTable)
+    forAll(peers) { peer =>
+      val (i, _) = kBuckets.getBucket(peer)
+      kBuckets.bucketLevelCounts(i).values.head shouldBe <=(defaultLimits.forBucket)
+    }
+  }
+
+  it should "not limit buckets if the bucket level limit is 0" in new Fixture {
+    override lazy val limits = defaultLimits.copy(forBucket = 0)
+    override lazy val ips = Vector.range(0, 256).map(i => s"192.168.1.$i")
+
+    kBuckets.tableLevelCounts.values.toList shouldBe List(limits.forTable)
+    forAtLeast(1, peers) { peer =>
+      val (i, _) = kBuckets.getBucket(peer)
+      kBuckets.bucketLevelCounts(i).values.head shouldBe >(defaultLimits.forBucket)
+    }
+  }
+}
diff --git a/scalanet/src/io/iohk/scalanet/peergroup/InetAddressOps.scala b/scalanet/src/io/iohk/scalanet/peergroup/InetAddressOps.scala
index ce71eaa6..dcfe01b2 100644
--- a/scalanet/src/io/iohk/scalanet/peergroup/InetAddressOps.scala
+++ b/scalanet/src/io/iohk/scalanet/peergroup/InetAddressOps.scala
@@ -3,6 +3,7 @@ package io.iohk.scalanet.peergroup
 import com.github.jgonian.ipmath.{Ipv6Range, Ipv4Range, Ipv4, Ipv6}
 import java.net.{InetAddress, Inet4Address, Inet6Address}
 import scala.language.implicitConversions
+import com.github.jgonian.ipmath.AbstractIp

 class InetAddressOps(val address: InetAddress) extends AnyVal {
   import InetAddressOps._
@@ -23,14 +24,30 @@ class InetAddressOps(val address: InetAddress) extends AnyVal {
     address == unspecified4 || address == unspecified6

   private def isInRange4(infos: List[Ipv4Range]): Boolean = {
-    val ip = Ipv4.of(address.getHostAddress)
+    val ip = toIpv4
     infos.exists(_.contains(ip))
   }

   private def isInRange6(infos: List[Ipv6Range]): Boolean = {
-    val ip = Ipv6.of(address.getHostAddress)
+    val ip = toIpv6
     infos.exists(_.contains(ip))
   }
+
+  private def toIpv4 =
+    Ipv4.of(address.getHostAddress)
+
+  private def toIpv6 =
+    Ipv6.of(address.getHostAddress)
+
+  private def toAbstractIp: AbstractIp[_, _] =
+    if (isIPv4) toIpv4 else toIpv6
+
+  private def toInetAddress(ip: AbstractIp[_, _]) =
+    InetAddress.getByName(ip.toString)
+
+  /** Truncate the IP address to the first `prefixLength` bits, zeroing out the rest. */
+  def truncate(prefixLength: Int): InetAddress =
+    toInetAddress(toAbstractIp.lowerBoundForPrefix(prefixLength))
 }

 object InetAddressOps {
diff --git a/scalanet/ut/src/io/iohk/scalanet/peergroup/InetAddressOpsSpec.scala b/scalanet/ut/src/io/iohk/scalanet/peergroup/InetAddressOpsSpec.scala
index 2381811c..b66abfae 100644
--- a/scalanet/ut/src/io/iohk/scalanet/peergroup/InetAddressOpsSpec.scala
+++ b/scalanet/ut/src/io/iohk/scalanet/peergroup/InetAddressOpsSpec.scala
@@ -45,4 +45,11 @@ class InetAddressOpsSpec extends FlatSpec with Matchers with Inspectors {
       }
     }
   }
+
+  behavior of "truncate"
+
+  it should "keep the first N bits and zero the rest" in {
+    val ip = InetAddress.getByName("192.175.48.127")
+    ip.truncate(24) shouldBe InetAddress.getByName("192.175.48.0")
+  }
 }
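Reviewer note, outside the diff: a minimal sketch of what the `truncate` extension and the default limits imply, using made-up addresses and a hypothetical demo object. It relies only on the `InetAddressOps` import shown above; with the defaults (`subnetLimitPrefixLength = 24`, `subnetLimitForBucket = 2`, `subnetLimitForTable = 10`) every address in the same /24 collapses to one counter key, so a third peer from that subnet is refused by a full bucket and an eleventh by a full table.

import java.net.InetAddress
import io.iohk.scalanet.peergroup.InetAddressOps._

object SubnetKeyDemo extends App {
  // Two hosts in 5.67.8.0/24 truncate to the same subnet key.
  val a = InetAddress.getByName("5.67.8.9").truncate(24)
  val b = InetAddress.getByName("5.67.8.200").truncate(24)
  assert(a == b && a == InetAddress.getByName("5.67.8.0"))

  // A prefix length of 0 collapses every address to 0.0.0.0, which is why
  // SubnetLimits treats prefixLength == 0 as "limits disabled" via isEnabled
  // rather than counting the whole internet as a single subnet.
  assert(InetAddress.getByName("5.67.8.9").truncate(0) == InetAddress.getByName("0.0.0.0"))
}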