
Commit e5db69f

Author: Yuval Itzchakov
Message: Merge remote-tracking branch 'upstream/branch-2.3' into branch-2.3
Parents: 4e366f8 + 8080c93

File tree

3 files changed: +53 additions, -17 deletions


core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala

Lines changed: 9 additions & 3 deletions
@@ -586,8 +586,9 @@ class BytesToString extends org.apache.spark.api.java.function.Function[Array[By
   */
 private[spark] class PythonAccumulatorV2(
     @transient private val serverHost: String,
-    private val serverPort: Int)
-  extends CollectionAccumulator[Array[Byte]] {
+    private val serverPort: Int,
+    private val secretToken: String)
+  extends CollectionAccumulator[Array[Byte]] with Logging {
 
   Utils.checkHost(serverHost)
 
@@ -602,12 +603,17 @@ private[spark] class PythonAccumulatorV2(
   private def openSocket(): Socket = synchronized {
     if (socket == null || socket.isClosed) {
       socket = new Socket(serverHost, serverPort)
+      logInfo(s"Connected to AccumulatorServer at host: $serverHost port: $serverPort")
+      // send the secret just for the initial authentication when opening a new connection
+      socket.getOutputStream.write(secretToken.getBytes(StandardCharsets.UTF_8))
     }
     socket
   }
 
   // Need to override so the types match with PythonFunction
-  override def copyAndReset(): PythonAccumulatorV2 = new PythonAccumulatorV2(serverHost, serverPort)
+  override def copyAndReset(): PythonAccumulatorV2 = {
+    new PythonAccumulatorV2(serverHost, serverPort, secretToken)
+  }
 
   override def merge(other: AccumulatorV2[Array[Byte], JList[Array[Byte]]]): Unit = synchronized {
     val otherPythonAccumulator = other.asInstanceOf[PythonAccumulatorV2]
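For orientation, the handshake this Scala change introduces is deliberately simple: on every freshly opened connection, PythonAccumulatorV2 writes the raw UTF-8 bytes of the secret token before any accumulator payload. The sketch below re-creates that client side in Python for illustration only; the function name and parameters are hypothetical, not Spark API.

import socket

def open_accumulator_socket(server_host, server_port, secret_token):
    # Hypothetical stand-in for PythonAccumulatorV2.openSocket(): connect,
    # then send the secret once, for the initial authentication only.
    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    sock.connect((server_host, server_port))
    # Mirrors socket.getOutputStream.write(secretToken.getBytes(UTF_8)).
    sock.sendall(secret_token.encode("utf-8"))
    return sock

Because copyAndReset() now threads secretToken through to the copy, every accumulator instance produced during task serialization can re-authenticate when it opens its own connection.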

python/pyspark/accumulators.py

Lines changed: 41 additions & 12 deletions
@@ -228,20 +228,49 @@ class _UpdateRequestHandler(SocketServer.StreamRequestHandler):
 
     def handle(self):
         from pyspark.accumulators import _accumulatorRegistry
-        while not self.server.server_shutdown:
-            # Poll every 1 second for new data -- don't block in case of shutdown.
-            r, _, _ = select.select([self.rfile], [], [], 1)
-            if self.rfile in r:
-                num_updates = read_int(self.rfile)
-                for _ in range(num_updates):
-                    (aid, update) = pickleSer._read_with_length(self.rfile)
-                    _accumulatorRegistry[aid] += update
-                # Write a byte in acknowledgement
-                self.wfile.write(struct.pack("!b", 1))
+        auth_token = self.server.auth_token
+
+        def poll(func):
+            while not self.server.server_shutdown:
+                # Poll every 1 second for new data -- don't block in case of shutdown.
+                r, _, _ = select.select([self.rfile], [], [], 1)
+                if self.rfile in r:
+                    if func():
+                        break
+
+        def accum_updates():
+            num_updates = read_int(self.rfile)
+            for _ in range(num_updates):
+                (aid, update) = pickleSer._read_with_length(self.rfile)
+                _accumulatorRegistry[aid] += update
+            # Write a byte in acknowledgement
+            self.wfile.write(struct.pack("!b", 1))
+            return False
+
+        def authenticate_and_accum_updates():
+            received_token = self.rfile.read(len(auth_token))
+            if isinstance(received_token, bytes):
+                received_token = received_token.decode("utf-8")
+            if (received_token == auth_token):
+                accum_updates()
+                # we've authenticated, we can break out of the first loop now
+                return True
+            else:
+                raise Exception(
+                    "The value of the provided token to the AccumulatorServer is not correct.")
+
+        # first we keep polling till we've received the authentication token
+        poll(authenticate_and_accum_updates)
+        # now we've authenticated, don't need to check for the token anymore
+        poll(accum_updates)
 
 
 class AccumulatorServer(SocketServer.TCPServer):
 
+    def __init__(self, server_address, RequestHandlerClass, auth_token):
+        SocketServer.TCPServer.__init__(self, server_address, RequestHandlerClass)
+        self.auth_token = auth_token
+
     """
     A simple TCP server that intercepts shutdown() in order to interrupt
     our continuous polling on the handler.
@@ -254,9 +283,9 @@ def shutdown(self):
         self.server_close()
 
 
-def _start_update_server():
+def _start_update_server(auth_token):
     """Start a TCP server to receive accumulator updates in a daemon thread, and returns it"""
-    server = AccumulatorServer(("localhost", 0), _UpdateRequestHandler)
+    server = AccumulatorServer(("localhost", 0), _UpdateRequestHandler, auth_token)
     thread = threading.Thread(target=server.serve_forever)
     thread.daemon = True
     thread.start()
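The protocol detail worth noting in this diff is that authentication is purely positional: the first len(auth_token) bytes on a new connection must equal the token, and everything after them is the unchanged update protocol (a read_int count, then pickled (aid, update) pairs, then a one-byte acknowledgement). The toy below demonstrates just that framing outside Spark, using a socketpair and an illustrative token; read_int in pyspark.serializers is equivalent to unpacking 4 big-endian bytes with struct.

import socket
import struct

AUTH_TOKEN = "example-secret"  # illustrative only; the real token comes from Py4J

server_sock, client_sock = socket.socketpair()

# Client side: the token goes first, then a length-prefixed batch (empty here).
client_sock.sendall(AUTH_TOKEN.encode("utf-8"))
client_sock.sendall(struct.pack("!i", 0))  # zero updates, just to show framing

# Server side: read exactly len(AUTH_TOKEN) bytes, compare, then proceed.
rfile = server_sock.makefile("rb")
received = rfile.read(len(AUTH_TOKEN)).decode("utf-8")
assert received == AUTH_TOKEN  # the real handler raises on a mismatch
num_updates = struct.unpack("!i", rfile.read(4))[0]
print("authenticated, pending updates:", num_updates)

One design consequence of the two poll() calls: authenticate_and_accum_updates already processes the batch that arrives alongside the token, so the second loop, poll(accum_updates), only has to handle subsequent batches on the now-authenticated connection.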

python/pyspark/context.py

Lines changed: 3 additions & 2 deletions
@@ -183,9 +183,10 @@ def _do_init(self, master, appName, sparkHome, pyFiles, environment, batchSize,
 
         # Create a single Accumulator in Java that we'll send all our updates through;
         # they will be passed back to us through a TCP server
-        self._accumulatorServer = accumulators._start_update_server()
+        auth_token = self._gateway.gateway_parameters.auth_token
+        self._accumulatorServer = accumulators._start_update_server(auth_token)
         (host, port) = self._accumulatorServer.server_address
-        self._javaAccumulator = self._jvm.PythonAccumulatorV2(host, port)
+        self._javaAccumulator = self._jvm.PythonAccumulatorV2(host, port, auth_token)
         self._jsc.sc().register(self._javaAccumulator)
 
         self.pythonExec = os.environ.get("PYSPARK_PYTHON", 'python')
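Rather than minting a new secret, the change reuses the Py4J gateway's existing auth token, so the Python-side AccumulatorServer and the JVM-side PythonAccumulatorV2 are configured from a single source of truth. A minimal sketch of where that attribute lives, assuming Py4J 0.10.7+ and an illustrative token value (in a real SparkContext the gateway is created for you):

from py4j.java_gateway import JavaGateway, GatewayParameters

gateway = JavaGateway(
    gateway_parameters=GatewayParameters(auth_token="example-secret"))
# The same attribute path that _do_init reads above:
token = gateway.gateway_parameters.auth_token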
