Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 4 additions & 11 deletions core/src/main/scala/org/apache/spark/network/Connection.scala
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,6 @@

package org.apache.spark.network

import org.apache.spark._
import org.apache.spark.SparkSaslServer

import scala.collection.mutable.{HashMap, Queue, ArrayBuffer}

import java.net._
import java.nio._
import java.nio.channels._
Expand All @@ -41,7 +36,7 @@ abstract class Connection(val channel: SocketChannel, val selector: Selector,
def this(channel_ : SocketChannel, selector_ : Selector, id_ : ConnectionId) = {
this(channel_, selector_,
ConnectionManagerId.fromSocketAddress(
channel_.socket.getRemoteSocketAddress().asInstanceOf[InetSocketAddress]), id_)
channel_.socket.getRemoteSocketAddress.asInstanceOf[InetSocketAddress]), id_)
}

channel.configureBlocking(false)
Expand Down Expand Up @@ -89,7 +84,7 @@ abstract class Connection(val channel: SocketChannel, val selector: Selector,

private def disposeSasl() {
if (sparkSaslServer != null) {
sparkSaslServer.dispose();
sparkSaslServer.dispose()
}

if (sparkSaslClient != null) {
Expand Down Expand Up @@ -328,15 +323,13 @@ class SendingConnection(val address: InetSocketAddress, selector_ : Selector,
// Is highly unlikely unless there was an unclean close of socket, etc
registerInterest()
logInfo("Connected to [" + address + "], " + outbox.messages.size + " messages pending")
true
} catch {
case e: Exception => {
logWarning("Error finishing connection to " + address, e)
callOnExceptionCallback(e)
// ignore
return true
}
}
true
}

override def write(): Boolean = {
Expand Down Expand Up @@ -546,7 +539,7 @@ private[spark] class ReceivingConnection(
/* println("Filled buffer at " + System.currentTimeMillis) */
val bufferMessage = inbox.getMessageForChunk(currentChunk).get
if (bufferMessage.isCompletelyReceived) {
bufferMessage.flip
bufferMessage.flip()
bufferMessage.finishTime = System.currentTimeMillis
logDebug("Finished receiving [" + bufferMessage + "] from " +
"[" + getRemoteConnectionManagerId() + "] in " + bufferMessage.timeTaken)
Expand Down
114 changes: 57 additions & 57 deletions core/src/main/scala/org/apache/spark/network/ConnectionManager.scala
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,7 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf,
def run() {
try {
while(!selectorThread.isInterrupted) {
while (! registerRequests.isEmpty) {
while (!registerRequests.isEmpty) {
val conn: SendingConnection = registerRequests.dequeue()
addListeners(conn)
conn.connect()
Expand Down Expand Up @@ -308,7 +308,7 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf,
// Some keys within the selectors list are invalid/closed. clear them.
val allKeys = selector.keys().iterator()

while (allKeys.hasNext()) {
while (allKeys.hasNext) {
val key = allKeys.next()
try {
if (! key.isValid) {
Expand Down Expand Up @@ -341,7 +341,7 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf,

if (0 != selectedKeysCount) {
val selectedKeys = selector.selectedKeys().iterator()
while (selectedKeys.hasNext()) {
while (selectedKeys.hasNext) {
val key = selectedKeys.next
selectedKeys.remove()
try {
Expand Down Expand Up @@ -419,62 +419,63 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf,
connectionsByKey -= connection.key

try {
if (connection.isInstanceOf[SendingConnection]) {
val sendingConnection = connection.asInstanceOf[SendingConnection]
val sendingConnectionManagerId = sendingConnection.getRemoteConnectionManagerId()
logInfo("Removing SendingConnection to " + sendingConnectionManagerId)

connectionsById -= sendingConnectionManagerId
connectionsAwaitingSasl -= connection.connectionId
connection match {
case sendingConnection: SendingConnection =>
val sendingConnectionManagerId = sendingConnection.getRemoteConnectionManagerId()
logInfo("Removing SendingConnection to " + sendingConnectionManagerId)

connectionsById -= sendingConnectionManagerId
connectionsAwaitingSasl -= connection.connectionId

messageStatuses.synchronized {
messageStatuses.values.filter(_.connectionManagerId == sendingConnectionManagerId)
.foreach(status => {
logInfo("Notifying " + status)
status.synchronized {
status.attempted = true
status.acked = false
status.markDone()
}
})

messageStatuses.synchronized {
messageStatuses
.values.filter(_.connectionManagerId == sendingConnectionManagerId).foreach(status => {
logInfo("Notifying " + status)
status.synchronized {
status.attempted = true
status.acked = false
status.markDone()
}
messageStatuses.retain((i, status) => {
status.connectionManagerId != sendingConnectionManagerId
})
}
case receivingConnection: ReceivingConnection =>
val remoteConnectionManagerId = receivingConnection.getRemoteConnectionManagerId()
logInfo("Removing ReceivingConnection to " + remoteConnectionManagerId)

messageStatuses.retain((i, status) => {
status.connectionManagerId != sendingConnectionManagerId
})
}
} else if (connection.isInstanceOf[ReceivingConnection]) {
val receivingConnection = connection.asInstanceOf[ReceivingConnection]
val remoteConnectionManagerId = receivingConnection.getRemoteConnectionManagerId()
logInfo("Removing ReceivingConnection to " + remoteConnectionManagerId)

val sendingConnectionOpt = connectionsById.get(remoteConnectionManagerId)
if (! sendingConnectionOpt.isDefined) {
logError("Corresponding SendingConnectionManagerId not found")
return
}
val sendingConnectionOpt = connectionsById.get(remoteConnectionManagerId)
if (!sendingConnectionOpt.isDefined) {
logError("Corresponding SendingConnectionManagerId not found")
return
}

val sendingConnection = sendingConnectionOpt.get
connectionsById -= remoteConnectionManagerId
sendingConnection.close()
val sendingConnection = sendingConnectionOpt.get
connectionsById -= remoteConnectionManagerId
sendingConnection.close()

val sendingConnectionManagerId = sendingConnection.getRemoteConnectionManagerId()
val sendingConnectionManagerId = sendingConnection.getRemoteConnectionManagerId()

assert (sendingConnectionManagerId == remoteConnectionManagerId)
assert(sendingConnectionManagerId == remoteConnectionManagerId)

messageStatuses.synchronized {
for (s <- messageStatuses.values if s.connectionManagerId == sendingConnectionManagerId) {
logInfo("Notifying " + s)
s.synchronized {
s.attempted = true
s.acked = false
s.markDone()
messageStatuses.synchronized {
for (s <- messageStatuses.values
if s.connectionManagerId == sendingConnectionManagerId) {
logInfo("Notifying " + s)
s.synchronized {
s.attempted = true
s.acked = false
s.markDone()
}
}
}

messageStatuses.retain((i, status) => {
status.connectionManagerId != sendingConnectionManagerId
})
}
messageStatuses.retain((i, status) => {
status.connectionManagerId != sendingConnectionManagerId
})
}
case _ => logError("Unsupported type of connection.")
}
} finally {
// So that the selection keys can be removed.
Expand Down Expand Up @@ -517,13 +518,13 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf,
logDebug("Client sasl completed for id: " + waitingConn.connectionId)
connectionsAwaitingSasl -= waitingConn.connectionId
waitingConn.getAuthenticated().synchronized {
waitingConn.getAuthenticated().notifyAll();
waitingConn.getAuthenticated().notifyAll()
}
return
} else {
var replyToken : Array[Byte] = null
try {
replyToken = waitingConn.sparkSaslClient.saslResponse(securityMsg.getToken);
replyToken = waitingConn.sparkSaslClient.saslResponse(securityMsg.getToken)
if (waitingConn.isSaslComplete()) {
logDebug("Client sasl completed after evaluate for id: " + waitingConn.connectionId)
connectionsAwaitingSasl -= waitingConn.connectionId
Expand All @@ -533,7 +534,7 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf,
return
}
val securityMsgResp = SecurityMessage.fromResponse(replyToken,
securityMsg.getConnectionId.toString())
securityMsg.getConnectionId.toString)
val message = securityMsgResp.toBufferMessage
if (message == null) throw new Exception("Error creating security message")
sendSecurityMessage(waitingConn.getRemoteConnectionManagerId(), message)
Expand Down Expand Up @@ -630,13 +631,13 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf,
case bufferMessage: BufferMessage => {
if (authEnabled) {
val res = handleAuthentication(connection, bufferMessage)
if (res == true) {
if (res) {
// message was security negotiation so skip the rest
logDebug("After handleAuth result was true, returning")
return
}
}
if (bufferMessage.hasAckId) {
if (bufferMessage.hasAckId()) {
val sentMessageStatus = messageStatuses.synchronized {
messageStatuses.get(bufferMessage.ackId) match {
case Some(status) => {
Expand All @@ -646,7 +647,6 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf,
case None => {
throw new Exception("Could not find reference for received ack message " +
message.id)
null
}
}
}
Expand All @@ -668,7 +668,7 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf,
if (ackMessage.isDefined) {
if (!ackMessage.get.isInstanceOf[BufferMessage]) {
logDebug("Response to " + bufferMessage + " is not a buffer message, it is of type "
+ ackMessage.get.getClass())
+ ackMessage.get.getClass)
} else if (!ackMessage.get.asInstanceOf[BufferMessage].hasAckId) {
logDebug("Response to " + bufferMessage + " does not have ack id set")
ackMessage.get.asInstanceOf[BufferMessage].ackId = bufferMessage.id
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,6 @@ private[spark] case class ConnectionManagerId(host: String, port: Int) {

private[spark] object ConnectionManagerId {
def fromSocketAddress(socketAddress: InetSocketAddress): ConnectionManagerId = {
new ConnectionManagerId(socketAddress.getHostName(), socketAddress.getPort())
new ConnectionManagerId(socketAddress.getHostName, socketAddress.getPort)
}
}