Skip to content

Commit 2b44cf1

Browse files
rxinaarondav
authored andcommitted
Added more documentation.
1 parent 1760d32 commit 2b44cf1

File tree

6 files changed

+80
-58
lines changed

6 files changed

+80
-58
lines changed

core/src/main/scala/org/apache/spark/network/netty/BlockClient.scala

Lines changed: 14 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -19,68 +19,35 @@ package org.apache.spark.network.netty
1919

2020
import java.util.concurrent.TimeoutException
2121

22-
import io.netty.bootstrap.Bootstrap
23-
import io.netty.buffer.PooledByteBufAllocator
24-
import io.netty.channel.socket.SocketChannel
25-
import io.netty.channel.{ChannelFuture, ChannelFutureListener, ChannelInitializer, ChannelOption}
22+
import io.netty.channel.{ChannelFuture, ChannelFutureListener}
2623

2724
import org.apache.spark.Logging
2825
import org.apache.spark.network.BlockFetchingListener
2926

3027

3128
/**
32-
* Client for [[NettyBlockTransferService]]. Use [[BlockClientFactory]] to
33-
* instantiate this client.
29+
* Client for [[NettyBlockTransferService]]. The connection to server must have been established
30+
* using [[BlockClientFactory]] before instantiating this.
3431
*
35-
* The constructor blocks until a connection is successfully established.
32+
* This class is used to make requests to the server , while [[BlockClientHandler]] is responsible
33+
* for handling responses from the server.
3634
*
3735
* Concurrency: thread safe and can be called from multiple threads.
36+
*
37+
* @param cf the ChannelFuture for the connection.
38+
* @param handler [[BlockClientHandler]] for handling outstanding requests.
3839
*/
3940
@throws[TimeoutException]
4041
private[netty]
41-
class BlockClient(factory: BlockClientFactory, hostname: String, port: Int)
42-
extends Logging {
43-
44-
private val handler = new BlockClientHandler
45-
private val encoder = new ClientRequestEncoder
46-
private val decoder = new ServerResponseDecoder
47-
48-
/** Netty Bootstrap for creating the TCP connection. */
49-
private val bootstrap: Bootstrap = {
50-
val b = new Bootstrap
51-
b.group(factory.workerGroup)
52-
.channel(factory.socketChannelClass)
53-
// Use pooled buffers to reduce temporary buffer allocation
54-
.option(ChannelOption.ALLOCATOR, PooledByteBufAllocator.DEFAULT)
55-
// Disable Nagle's Algorithm since we don't want packets to wait
56-
.option(ChannelOption.TCP_NODELAY, java.lang.Boolean.TRUE)
57-
.option(ChannelOption.SO_KEEPALIVE, java.lang.Boolean.TRUE)
58-
.option[Integer](ChannelOption.CONNECT_TIMEOUT_MILLIS, factory.conf.connectTimeoutMs)
59-
60-
b.handler(new ChannelInitializer[SocketChannel] {
61-
override def initChannel(ch: SocketChannel): Unit = {
62-
ch.pipeline
63-
.addLast("clientRequestEncoder", encoder)
64-
.addLast("frameDecoder", ProtocolUtils.createFrameDecoder())
65-
.addLast("serverResponseDecoder", decoder)
66-
.addLast("handler", handler)
67-
}
68-
})
69-
b
70-
}
42+
class BlockClient(cf: ChannelFuture, handler: BlockClientHandler) extends Logging {
7143

72-
/** Netty ChannelFuture for the connection. */
73-
private val cf: ChannelFuture = bootstrap.connect(hostname, port)
74-
if (!cf.awaitUninterruptibly(factory.conf.connectTimeoutMs)) {
75-
throw new TimeoutException(
76-
s"Connecting to $hostname:$port timed out (${factory.conf.connectTimeoutMs} ms)")
77-
}
44+
private[this] val serverAddr = cf.channel().remoteAddress().toString
7845

7946
/**
8047
* Ask the remote server for a sequence of blocks, and execute the callback.
8148
*
8249
* Note that this is asynchronous and returns immediately. Upstream caller should throttle the
83-
* rate of fetching; otherwise we could run out of memory.
50+
* rate of fetching; otherwise we could run out of memory due to large outstanding fetches.
8451
*
8552
* @param blockIds sequence of block ids to fetch.
8653
* @param listener callback to fire on fetch success / failure.
@@ -89,7 +56,7 @@ class BlockClient(factory: BlockClientFactory, hostname: String, port: Int)
8956
var startTime: Long = 0
9057
logTrace {
9158
startTime = System.nanoTime
92-
s"Sending request $blockIds to $hostname:$port"
59+
s"Sending request $blockIds to $serverAddr"
9360
}
9461

9562
blockIds.foreach { blockId =>
@@ -101,12 +68,12 @@ class BlockClient(factory: BlockClientFactory, hostname: String, port: Int)
10168
if (future.isSuccess) {
10269
logTrace {
10370
val timeTaken = (System.nanoTime - startTime).toDouble / 1000000
104-
s"Sending request $blockIds to $hostname:$port took $timeTaken ms"
71+
s"Sending request $blockIds to $serverAddr took $timeTaken ms"
10572
}
10673
} else {
10774
// Fail all blocks.
10875
val errorMsg =
109-
s"Failed to send request $blockIds to $hostname:$port: ${future.cause.getMessage}"
76+
s"Failed to send request $blockIds to $serverAddr: ${future.cause.getMessage}"
11077
logError(errorMsg, future.cause)
11178
blockIds.foreach { blockId =>
11279
handler.removeRequest(blockId)

core/src/main/scala/org/apache/spark/network/netty/BlockClientFactory.scala

Lines changed: 41 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,17 @@
1717

1818
package org.apache.spark.network.netty
1919

20+
import java.util.concurrent.TimeoutException
21+
22+
import io.netty.bootstrap.Bootstrap
23+
import io.netty.buffer.PooledByteBufAllocator
24+
import io.netty.channel._
2025
import io.netty.channel.epoll.{Epoll, EpollEventLoopGroup, EpollSocketChannel}
2126
import io.netty.channel.nio.NioEventLoopGroup
2227
import io.netty.channel.oio.OioEventLoopGroup
28+
import io.netty.channel.socket.SocketChannel
2329
import io.netty.channel.socket.nio.NioSocketChannel
2430
import io.netty.channel.socket.oio.OioSocketChannel
25-
import io.netty.channel.{Channel, EventLoopGroup}
2631

2732
import org.apache.spark.SparkConf
2833
import org.apache.spark.util.Utils
@@ -38,12 +43,16 @@ class BlockClientFactory(val conf: NettyConfig) {
3843
def this(sparkConf: SparkConf) = this(new NettyConfig(sparkConf))
3944

4045
/** A thread factory so the threads are named (for debugging). */
41-
private[netty] val threadFactory = Utils.namedThreadFactory("spark-shuffle-client")
46+
private[netty] val threadFactory = Utils.namedThreadFactory("spark-netty-client")
4247

4348
/** The following two are instantiated by the [[init]] method, depending ioMode. */
4449
private[netty] var socketChannelClass: Class[_ <: Channel] = _
4550
private[netty] var workerGroup: EventLoopGroup = _
4651

52+
// The encoders are stateless and can be shared among multiple clients.
53+
private[this] val encoder = new ClientRequestEncoder
54+
private[this] val decoder = new ServerResponseDecoder
55+
4756
init()
4857

4958
/** Initialize [[socketChannelClass]] and [[workerGroup]] based on ioMode. */
@@ -78,7 +87,36 @@ class BlockClientFactory(val conf: NettyConfig) {
7887
* Concurrency: This method is safe to call from multiple threads.
7988
*/
8089
def createClient(remoteHost: String, remotePort: Int): BlockClient = {
81-
new BlockClient(this, remoteHost, remotePort)
90+
val handler = new BlockClientHandler
91+
92+
val bootstrap = new Bootstrap
93+
bootstrap.group(workerGroup)
94+
.channel(socketChannelClass)
95+
// Use pooled buffers to reduce temporary buffer allocation
96+
.option(ChannelOption.ALLOCATOR, PooledByteBufAllocator.DEFAULT)
97+
// Disable Nagle's Algorithm since we don't want packets to wait
98+
.option(ChannelOption.TCP_NODELAY, java.lang.Boolean.TRUE)
99+
.option(ChannelOption.SO_KEEPALIVE, java.lang.Boolean.TRUE)
100+
.option[Integer](ChannelOption.CONNECT_TIMEOUT_MILLIS, conf.connectTimeoutMs)
101+
102+
bootstrap.handler(new ChannelInitializer[SocketChannel] {
103+
override def initChannel(ch: SocketChannel): Unit = {
104+
ch.pipeline
105+
.addLast("clientRequestEncoder", encoder)
106+
.addLast("frameDecoder", ProtocolUtils.createFrameDecoder())
107+
.addLast("serverResponseDecoder", decoder)
108+
.addLast("handler", handler)
109+
}
110+
})
111+
112+
// Connect to the remote server
113+
val cf: ChannelFuture = bootstrap.connect(remoteHost, remotePort)
114+
if (!cf.awaitUninterruptibly(conf.connectTimeoutMs)) {
115+
throw new TimeoutException(
116+
s"Connecting to $remoteHost:$remotePort timed out (${conf.connectTimeoutMs} ms)")
117+
}
118+
119+
new BlockClient(cf, handler)
82120
}
83121

84122
def stop(): Unit = {

core/src/main/scala/org/apache/spark/network/netty/BlockClientHandler.scala

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,15 +24,16 @@ import org.apache.spark.network.BlockFetchingListener
2424

2525

2626
/**
27-
* Handler that processes server responses.
27+
* Handler that processes server responses, in response to requests issued from [[BlockClient]].
28+
* It works by tracking the list of outstanding requests (and their callbacks).
2829
*
2930
* Concurrency: thread safe and can be called from multiple threads.
3031
*/
3132
private[netty]
3233
class BlockClientHandler extends SimpleChannelInboundHandler[ServerResponse] with Logging {
3334

3435
/** Tracks the list of outstanding requests and their listeners on success/failure. */
35-
private val outstandingRequests = java.util.Collections.synchronizedMap {
36+
private[this] val outstandingRequests = java.util.Collections.synchronizedMap {
3637
new java.util.HashMap[String, BlockFetchingListener]
3738
}
3839

core/src/main/scala/org/apache/spark/network/netty/BlockServer.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -58,8 +58,8 @@ class BlockServer(conf: NettyConfig, dataProvider: BlockDataManager) extends Log
5858
/** Initialize the server. */
5959
private def init(): Unit = {
6060
bootstrap = new ServerBootstrap
61-
val bossThreadFactory = Utils.namedThreadFactory("spark-shuffle-server-boss")
62-
val workerThreadFactory = Utils.namedThreadFactory("spark-shuffle-server-worker")
61+
val bossThreadFactory = Utils.namedThreadFactory("spark-netty-server-boss")
62+
val workerThreadFactory = Utils.namedThreadFactory("spark-netty-server-worker")
6363

6464
// Use only one thread to accept connections, and 2 * num_cores for worker.
6565
def initNio(): Unit = {

core/src/main/scala/org/apache/spark/network/netty/protocol.scala

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,37 +28,50 @@ import org.apache.spark.Logging
2828
import org.apache.spark.network.{NettyByteBufManagedBuffer, ManagedBuffer}
2929

3030

31+
/** Messages from the client to the server. */
3132
sealed trait ClientRequest {
3233
def id: Byte
3334
}
3435

36+
/**
37+
* Request to fetch a sequence of blocks from the server. A single [[BlockFetchRequest]] can
38+
* correspond to multiple [[ServerResponse]]s.
39+
*/
3540
final case class BlockFetchRequest(blocks: Seq[String]) extends ClientRequest {
3641
override def id = 0
3742
}
3843

44+
/**
45+
* Request to upload a block to the server. Currently the server does not ack the upload request.
46+
*/
3947
final case class BlockUploadRequest(blockId: String, data: ManagedBuffer) extends ClientRequest {
4048
require(blockId.length <= Byte.MaxValue)
4149
override def id = 1
4250
}
4351

4452

53+
/** Messages from server to client (usually in response to some [[ClientRequest]]. */
4554
sealed trait ServerResponse {
4655
def id: Byte
4756
}
4857

58+
/** Response to [[BlockFetchRequest]] when a block exists and has been successfully fetched. */
4959
final case class BlockFetchSuccess(blockId: String, data: ManagedBuffer) extends ServerResponse {
5060
require(blockId.length <= Byte.MaxValue)
5161
override def id = 0
5262
}
5363

64+
/** Response to [[BlockFetchRequest]] when there is an error fetching the block. */
5465
final case class BlockFetchFailure(blockId: String, error: String) extends ServerResponse {
5566
require(blockId.length <= Byte.MaxValue)
5667
override def id = 1
5768
}
5869

5970

6071
/**
61-
* Encoder used by the client side to encode client-to-server responses.
72+
* Encoder for [[ClientRequest]] used in client side.
73+
*
74+
* This encoder is stateless so it is safe to be shared by multiple threads.
6275
*/
6376
@Sharable
6477
final class ClientRequestEncoder extends MessageToMessageEncoder[ClientRequest] {
@@ -109,6 +122,7 @@ final class ClientRequestEncoder extends MessageToMessageEncoder[ClientRequest]
109122

110123
/**
111124
* Decoder in the server side to decode client requests.
125+
* This decoder is stateless so it is safe to be shared by multiple threads.
112126
*
113127
* This assumes the inbound messages have been processed by a frame decoder created by
114128
* [[ProtocolUtils.createFrameDecoder()]].
@@ -138,6 +152,7 @@ final class ClientRequestDecoder extends MessageToMessageDecoder[ByteBuf] {
138152

139153
/**
140154
* Encoder used by the server side to encode server-to-client responses.
155+
* This encoder is stateless so it is safe to be shared by multiple threads.
141156
*/
142157
@Sharable
143158
final class ServerResponseEncoder extends MessageToMessageEncoder[ServerResponse] with Logging {
@@ -190,6 +205,7 @@ final class ServerResponseEncoder extends MessageToMessageEncoder[ServerResponse
190205

191206
/**
192207
* Decoder in the client side to decode server responses.
208+
* This decoder is stateless so it is safe to be shared by multiple threads.
193209
*
194210
* This assumes the inbound messages have been processed by a frame decoder created by
195211
* [[ProtocolUtils.createFrameDecoder()]].
@@ -229,6 +245,7 @@ private[netty] object ProtocolUtils {
229245
new LengthFieldBasedFrameDecoder(Int.MaxValue, 0, 8, -8, 8)
230246
}
231247

248+
// TODO(rxin): Make sure these work for all charsets.
232249
def readBlockId(in: ByteBuf): String = {
233250
val numBytesToRead = in.readByte().toInt
234251
val bytes = new Array[Byte](numBytesToRead)

core/src/test/scala/org/apache/spark/network/netty/ServerClientIntegrationSuite.scala

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ import org.apache.spark.storage.StorageLevel
3434

3535

3636
/**
37-
* Test suite that makes sure the server and the client implementations share the same protocol.
37+
* Test cases that create real clients and servers and connect.
3838
*/
3939
class ServerClientIntegrationSuite extends FunSuite with BeforeAndAfterAll {
4040

@@ -93,8 +93,7 @@ class ServerClientIntegrationSuite extends FunSuite with BeforeAndAfterAll {
9393
/** A ByteBuf for file_block */
9494
lazy val fileBlockReference = Unpooled.wrappedBuffer(fileContent, 10, fileContent.length - 25)
9595

96-
def fetchBlocks(blockIds: Seq[String]): (Set[String], Set[ManagedBuffer], Set[String]) =
97-
{
96+
def fetchBlocks(blockIds: Seq[String]): (Set[String], Set[ManagedBuffer], Set[String]) = {
9897
val client = clientFactory.createClient(server.hostName, server.port)
9998
val sem = new Semaphore(0)
10099
val receivedBlockIds = Collections.synchronizedSet(new HashSet[String])

0 commit comments

Comments
 (0)