From 8d2f08cfdb62ea539f35e0a8168d104212a13988 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Wed, 3 Sep 2014 17:08:35 -0700 Subject: [PATCH 01/13] reuse python worker --- .../scala/org/apache/spark/SparkEnv.scala | 8 ++++ .../apache/spark/api/python/PythonRDD.scala | 15 +++---- .../api/python/PythonWorkerFactory.scala | 18 ++++++++ docs/configuration.md | 10 +++++ python/pyspark/daemon.py | 44 +++++++++---------- python/pyspark/serializers.py | 4 ++ python/run-tests | 2 +- 7 files changed, 67 insertions(+), 34 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index 72716567ca99b..52c8f0ad29212 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -105,6 +105,14 @@ class SparkEnv ( pythonWorkers.get(key).foreach(_.stopWorker(worker)) } } + + private[spark] + def releasePythonWorker(pythonExec: String, envVars: Map[String, String], worker: Socket) { + synchronized { + val key = (pythonExec, envVars) + pythonWorkers.get(key).foreach(_.releaseWorker(worker)) + } + } } object SparkEnv extends Logging { diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index ae8010300a500..86ffdd9f80bef 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -52,6 +52,7 @@ private[spark] class PythonRDD( extends RDD[Array[Byte]](parent) { val bufferSize = conf.getInt("spark.buffer.size", 65536) + val reuse_worker = conf.getBoolean("spark.python.reuse.worker", true) override def getPartitions = parent.partitions @@ -63,6 +64,9 @@ private[spark] class PythonRDD( val localdir = env.blockManager.diskBlockManager.localDirs.map( f => f.getPath()).mkString(",") envVars += ("SPARK_LOCAL_DIRS" -> localdir) // it's also used in monitor thread + if (reuse_worker) { + envVars += ("SPARK_REUSE_WORKER" -> "1") + } val worker: Socket = env.createPythonWorker(pythonExec, envVars.toMap) // Start a thread to feed the process input from our parent's iterator @@ -70,13 +74,7 @@ private[spark] class PythonRDD( context.addTaskCompletionListener { context => writerThread.shutdownOnTaskCompletion() - - // Cleanup the worker socket. This will also cause the Python worker to exit. - try { - worker.close() - } catch { - case e: Exception => logWarning("Failed to close worker socket", e) - } + env.releasePythonWorker(pythonExec, envVars.toMap, worker) } writerThread.start() @@ -207,6 +205,7 @@ private[spark] class PythonRDD( dataOut.write(command) // Data values PythonRDD.writeIteratorToStream(parent.iterator(split, context), dataOut) + dataOut.writeInt(SpecialLengths.END_OF_DATA_SECTION) dataOut.flush() } catch { case e: Exception if context.isCompleted || context.isInterrupted => @@ -216,8 +215,6 @@ private[spark] class PythonRDD( // We must avoid throwing exceptions here, because the thread uncaught exception handler // will kill the whole executor (see org.apache.spark.executor.Executor). 
_exception = e - } finally { - Try(worker.shutdownOutput()) // kill Python worker process } } } diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala index 4c4796f6c59ba..b25016be4bea5 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala @@ -41,6 +41,7 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String val daemonHost = InetAddress.getByAddress(Array(127, 0, 0, 1)) var daemonPort: Int = 0 var daemonWorkers = new mutable.WeakHashMap[Socket, Int]() + var idleWorkers = new mutable.Queue[Socket]() var simpleWorkers = new mutable.WeakHashMap[Socket, Process]() @@ -51,6 +52,9 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String def create(): Socket = { if (useDaemon) { + if (idleWorkers.length > 0) { + return idleWorkers.dequeue() + } createThroughDaemon() } else { createSimpleWorker() @@ -235,6 +239,20 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String } worker.close() } + + def releaseWorker(worker: Socket) { + if (useDaemon && envVars.get("SPARK_REUSE_WORKER").isDefined) { + idleWorkers.enqueue(worker) + } else { + // Cleanup the worker socket. This will also cause the Python worker to exit. + try { + worker.close() + } catch { + case e: Exception => + logWarning("Failed to close worker socket", e) + } + } + } } private object PythonWorkerFactory { diff --git a/docs/configuration.md b/docs/configuration.md index 65a422caabb7e..fc6cea16460ce 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -206,6 +206,16 @@ Apart from these, the following properties are also available, and may be useful used during aggregation goes above this amount, it will spill the data into disks. + + spark.python.worker.reuse + true + + Reuse Python worker or not. If yes, it will use a fixed number of Python workers, + does not need to fork() a Python process for every tasks. It will be very useful + if there is large broadcast, then the broadcast will not be needed to transfered + from JVM to Python worker for every task. + + spark.executorEnv.[EnvironmentVariableName] (none) diff --git a/python/pyspark/daemon.py b/python/pyspark/daemon.py index 22ab8d30c0ae3..64d6202acb27d 100644 --- a/python/pyspark/daemon.py +++ b/python/pyspark/daemon.py @@ -23,6 +23,7 @@ import sys import traceback import time +import gc from errno import EINTR, ECHILD, EAGAIN from socket import AF_INET, SOCK_STREAM, SOMAXCONN from signal import SIGHUP, SIGTERM, SIGCHLD, SIG_DFL, SIG_IGN @@ -42,25 +43,10 @@ def worker(sock): """ Called by a worker process after the fork(). """ - # Redirect stdout to stderr - os.dup2(2, 1) - sys.stdout = sys.stderr # The sys.stdout object is different from file descriptor 1 - signal.signal(SIGHUP, SIG_DFL) signal.signal(SIGCHLD, SIG_DFL) signal.signal(SIGTERM, SIG_DFL) - # Blocks until the socket is closed by draining the input stream - # until it raises an exception or returns EOF. - def waitSocketClose(sock): - try: - while True: - # Empty string is returned upon EOF (and only then). - if sock.recv(4096) == '': - return - except: - pass - # Read the socket using fdopen instead of socket.makefile() because the latter # seems to be very slow; note that we need to dup() the file descriptor because # otherwise writes also cause a seek that makes us miss data on the read side. 
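
The spark.python.worker.reuse property documented in the configuration.md hunk above is an ordinary per-application Spark setting; reuse is on by default after this patch, so most applications need no change. A minimal sketch of turning it off from PySpark — the local master and app name are illustrative, not part of the patch:

    # Minimal sketch: opting a single application out of Python worker reuse.
    # "local[2]" and the app name are placeholders; the property name and its
    # default of true come from the configuration.md hunk above.
    from pyspark import SparkConf, SparkContext

    conf = (SparkConf()
            .setMaster("local[2]")
            .setAppName("worker-reuse-demo")
            .set("spark.python.worker.reuse", "false"))  # fall back to one fork() per task

    sc = SparkContext(conf=conf)
    print(sc.parallelize(range(100), 4).map(str).count())  # prints 100
    sc.stop()
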
@@ -68,17 +54,13 @@ def waitSocketClose(sock): outfile = os.fdopen(os.dup(sock.fileno()), "a+", 65536) exit_code = 0 try: - # Acknowledge that the fork was successful - write_int(os.getpid(), outfile) - outfile.flush() worker_main(infile, outfile) except SystemExit as exc: - exit_code = exc.code + exit_code = compute_real_exit_code(exc.code) finally: outfile.flush() - # The Scala side will close the socket upon task completion. - waitSocketClose(sock) - os._exit(compute_real_exit_code(exit_code)) + if exit_code: + os._exit(exit_code) # Cleanup zombie children @@ -102,6 +84,7 @@ def manager(): listen_sock.listen(max(1024, SOMAXCONN)) listen_host, listen_port = listen_sock.getsockname() write_int(listen_port, sys.stdout) + sys.stdout.flush() def shutdown(code): signal.signal(SIGTERM, SIG_DFL) @@ -114,8 +97,9 @@ def handle_sigterm(*args): signal.signal(SIGTERM, handle_sigterm) # Gracefully exit on SIGTERM signal.signal(SIGHUP, SIG_IGN) # Don't die on SIGHUP + reuse = os.environ.get("SPARK_REUSE_WORKER") + # Initialization complete - sys.stdout.close() try: while True: try: @@ -167,7 +151,19 @@ def handle_sigterm(*args): # in child process listen_sock.close() try: - worker(sock) + # Acknowledge that the fork was successful + outfile = sock.makefile("w") + write_int(os.getpid(), outfile) + outfile.flush() + outfile.close() + while True: + worker(sock) + if not reuse: + # wait for closing + while sock.recv(1024): + pass + break + gc.collect() except: traceback.print_exc() os._exit(1) diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index fc49aa42dbaf9..64a058cf0826b 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -144,6 +144,8 @@ def _write_with_length(self, obj, stream): def _read_with_length(self, stream): length = read_int(stream) + if length == SpecialLengths.END_OF_DATA_SECTION: + raise EOFError obj = stream.read(length) if obj == "": raise EOFError @@ -431,6 +433,8 @@ class UTF8Deserializer(Serializer): def loads(self, stream): length = read_int(stream) + if length == SpecialLengths.END_OF_DATA_SECTION: + raise EOFError return stream.read(length).decode('utf8') def load_stream(self, stream): diff --git a/python/run-tests b/python/run-tests index 7b1ee3e1cddba..f2d017ca99c8a 100755 --- a/python/run-tests +++ b/python/run-tests @@ -50,7 +50,7 @@ echo "Running PySpark tests. Output is in python/unit-tests.log." 
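
The serializers.py hunk above makes both deserializers treat END_OF_DATA_SECTION as an end-of-stream marker, which is what lets a reused worker tell where one task's input stops before it loops back to wait for the next task. A rough sketch of that length-prefixed framing follows; it is not the PySpark implementation, and the concrete sentinel value (-1) is assumed here for illustration:

    # Sketch of length-prefixed framing with a negative sentinel length marking
    # the end of the data section. The real deserializers raise EOFError at the
    # sentinel instead of returning.
    import io
    import struct

    END_OF_DATA_SECTION = -1   # assumed sentinel value, for illustration only

    def write_frame(stream, payload):
        # Each payload is preceded by its length as a big-endian 32-bit int.
        stream.write(struct.pack("!i", len(payload)))
        stream.write(payload)

    def read_frames(stream):
        # Yield payloads until the sentinel length is seen.
        while True:
            length = struct.unpack("!i", stream.read(4))[0]
            if length == END_OF_DATA_SECTION:
                return
            yield stream.read(length)

    buf = io.BytesIO()
    write_frame(buf, b"record-1")
    write_frame(buf, b"record-2")
    buf.write(struct.pack("!i", END_OF_DATA_SECTION))
    buf.seek(0)
    assert list(read_frames(buf)) == [b"record-1", b"record-2"]
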
# Try to test with Python 2.6, since that's the minimum version that we support: if [ $(which python2.6) ]; then - export PYSPARK_PYTHON="python2.6" + export PYSPARK_PYTHON="pypy" fi echo "Testing with Python version:" From 6123d0f8e55dd78039f5c59023b12ab54651fe12 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Wed, 3 Sep 2014 18:06:45 -0700 Subject: [PATCH 02/13] track broadcasts for each worker --- .../apache/spark/api/python/PythonRDD.scala | 30 ++++++++++++++++--- python/pyspark/worker.py | 7 +++-- 2 files changed, 31 insertions(+), 6 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index 86ffdd9f80bef..7c7b69c8e1708 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -23,6 +23,7 @@ import java.nio.charset.Charset import java.util.{List => JList, ArrayList => JArrayList, Map => JMap, Collections} import scala.collection.JavaConversions._ +import scala.collection.mutable import scala.language.existentials import scala.reflect.ClassTag import scala.util.{Try, Success, Failure} @@ -193,11 +194,26 @@ private[spark] class PythonRDD( PythonRDD.writeUTF(include, dataOut) } // Broadcast variables - dataOut.writeInt(broadcastVars.length) + val bids = PythonRDD.getWorkerBroadcasts(worker) + val nbids = broadcastVars.map(_.id).toSet + // number of different broadcasts + val cnt = bids.diff(nbids).size + nbids.diff(bids).size + dataOut.writeInt(cnt) + for (bid <- bids) { + if (!nbids.contains(bid)) { + // remove the broadcast from worker + dataOut.writeLong(-bid) + bids.remove(bid) + } + } for (broadcast <- broadcastVars) { - dataOut.writeLong(broadcast.id) - dataOut.writeInt(broadcast.value.length) - dataOut.write(broadcast.value) + if (!bids.contains(broadcast.id)) { + // send new broadcast + dataOut.writeLong(broadcast.id) + dataOut.writeInt(broadcast.value.length) + dataOut.write(broadcast.value) + bids.add(broadcast.id) + } } dataOut.flush() // Serialized command: @@ -275,6 +291,12 @@ private object SpecialLengths { private[spark] object PythonRDD extends Logging { val UTF8 = Charset.forName("UTF-8") + // remember the broadcasts sent to each worker + private val workerBroadcasts = new mutable.WeakHashMap[Socket, mutable.Set[Long]]() + private def getWorkerBroadcasts(worker: Socket) = { + workerBroadcasts.getOrElseUpdate(worker, new mutable.HashSet[Long]()) + } + /** * Adapter for calling SparkContext#runJob from Python. 
* diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index 6805063e06798..e7d236888ac69 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -69,8 +69,11 @@ def main(infile, outfile): ser = CompressedSerializer(pickleSer) for _ in range(num_broadcast_variables): bid = read_long(infile) - value = ser._read_with_length(infile) - _broadcastRegistry[bid] = Broadcast(bid, value) + if bid > 0: + value = ser._read_with_length(infile) + _broadcastRegistry[bid] = Broadcast(bid, value) + else: + _broadcastRegistry.pop(-bid, None) command = pickleSer._read_with_length(infile) (func, deserializer, serializer) = command From ace2917a98a38c5838f37583ddb3d1fdfaa5ae2c Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Wed, 3 Sep 2014 19:01:30 -0700 Subject: [PATCH 03/13] kill python worker after timeout --- .../api/python/PythonWorkerFactory.scala | 45 +++++++++++++++++-- python/run-tests | 2 +- 2 files changed, 42 insertions(+), 5 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala index b25016be4bea5..8d494929c4c8c 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala @@ -40,8 +40,10 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String var daemon: Process = null val daemonHost = InetAddress.getByAddress(Array(127, 0, 0, 1)) var daemonPort: Int = 0 - var daemonWorkers = new mutable.WeakHashMap[Socket, Int]() - var idleWorkers = new mutable.Queue[Socket]() + val daemonWorkers = new mutable.WeakHashMap[Socket, Int]() + val idleWorkers = new mutable.Queue[Socket]() + var lastActivity = 0L + new MonitorThread().start() var simpleWorkers = new mutable.WeakHashMap[Socket, Process]() @@ -52,8 +54,10 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String def create(): Socket = { if (useDaemon) { - if (idleWorkers.length > 0) { - return idleWorkers.dequeue() + idleWorkers.synchronized { + if (idleWorkers.size > 0) { + return idleWorkers.dequeue() + } } createThroughDaemon() } else { @@ -203,6 +207,35 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String } } + /** + * Monitor all the idle workers, kill them after timeout. + */ + private class MonitorThread extends Thread(s"Idle Worker Monitor for $pythonExec") { + + setDaemon(true) + + override def run() { + while (true) { + idleWorkers.synchronized { + if (lastActivity + IDLE_WORKER_TIMEOUT_MS < System.currentTimeMillis()) { + while (idleWorkers.length > 0) { + val worker = idleWorkers.dequeue() + try { + // the Python worker will exit after closing the socket + worker.close() + } catch { + case e: Exception => + logWarning("Failed to close worker socket", e) + } + } + lastActivity = System.currentTimeMillis() + } + } + Thread.sleep(10000) + } + } + } + private def stopDaemon() { synchronized { if (useDaemon) { @@ -242,7 +275,10 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String def releaseWorker(worker: Socket) { if (useDaemon && envVars.get("SPARK_REUSE_WORKER").isDefined) { + idleWorkers.synchronized { + lastActivity = System.currentTimeMillis() idleWorkers.enqueue(worker) + } } else { // Cleanup the worker socket. This will also cause the Python worker to exit. 
try { @@ -257,4 +293,5 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String private object PythonWorkerFactory { val PROCESS_WAIT_TIMEOUT_MS = 10000 + val IDLE_WORKER_TIMEOUT_MS = 60000 } diff --git a/python/run-tests b/python/run-tests index f2d017ca99c8a..7b1ee3e1cddba 100755 --- a/python/run-tests +++ b/python/run-tests @@ -50,7 +50,7 @@ echo "Running PySpark tests. Output is in python/unit-tests.log." # Try to test with Python 2.6, since that's the minimum version that we support: if [ $(which python2.6) ]; then - export PYSPARK_PYTHON="pypy" + export PYSPARK_PYTHON="python2.6" fi echo "Testing with Python version:" From 583716ee156d0d7eb87220295a624bfc1427032a Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Wed, 3 Sep 2014 22:33:04 -0700 Subject: [PATCH 04/13] only reuse completed and not interrupted worker --- .../scala/org/apache/spark/api/python/PythonRDD.scala | 11 ++++++++++- .../apache/spark/api/python/PythonWorkerFactory.scala | 9 +++++++++ 2 files changed, 19 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index 7c7b69c8e1708..a9bd57832342b 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -75,7 +75,16 @@ private[spark] class PythonRDD( context.addTaskCompletionListener { context => writerThread.shutdownOnTaskCompletion() - env.releasePythonWorker(pythonExec, envVars.toMap, worker) + if (!context.isInterrupted) { + env.releasePythonWorker(pythonExec, envVars.toMap, worker) + } else { + try { + worker.close() + } catch { + case e: Exception => + logWarning("Failed to close worker socket", e) + } + } } writerThread.start() diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala index 8d494929c4c8c..118380c369db9 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala @@ -239,6 +239,15 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String private def stopDaemon() { synchronized { if (useDaemon) { + while (idleWorkers.length > 0) { + val worker = idleWorkers.dequeue() + try { + worker.close() + } catch { + case e: Exception => + logWarning("Failed to close worker socket", e) + } + } // Request shutdown of existing daemon by sending SIGTERM if (daemon != null) { daemon.destroy() From e0131a22be8a053408297b5677d4ef6b5736ce37 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Sat, 6 Sep 2014 17:26:36 -0700 Subject: [PATCH 05/13] fix name of config --- core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index a9bd57832342b..03f9c568923b4 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -53,7 +53,7 @@ private[spark] class PythonRDD( extends RDD[Array[Byte]](parent) { val bufferSize = conf.getInt("spark.buffer.size", 65536) - val reuse_worker = conf.getBoolean("spark.python.reuse.worker", true) + val reuse_worker = conf.getBoolean("spark.python.worker.reuse", true) override def getPartitions = 
parent.partitions From 6325fc16bbf71a077f77af0938976eee7732a813 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Sat, 6 Sep 2014 19:29:46 -0700 Subject: [PATCH 06/13] bugfix: bid >= 0 --- .../main/scala/org/apache/spark/api/python/PythonRDD.scala | 2 +- python/pyspark/worker.py | 5 +++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index 03f9c568923b4..ba9adce330db0 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -211,7 +211,7 @@ private[spark] class PythonRDD( for (bid <- bids) { if (!nbids.contains(bid)) { // remove the broadcast from worker - dataOut.writeLong(-bid) + dataOut.writeLong(- bid - 1) // bid >= 0 bids.remove(bid) } } diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index e7d236888ac69..77254f599aa04 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -69,11 +69,12 @@ def main(infile, outfile): ser = CompressedSerializer(pickleSer) for _ in range(num_broadcast_variables): bid = read_long(infile) - if bid > 0: + if bid >= 0: value = ser._read_with_length(infile) _broadcastRegistry[bid] = Broadcast(bid, value) else: - _broadcastRegistry.pop(-bid, None) + bid = - bid - 1 + _broadcastRegistry.pop(bid, None) command = pickleSer._read_with_length(infile) (func, deserializer, serializer) = command From 8911f44f69d1675ee43956b47276ca05c52a3ee0 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Mon, 8 Sep 2014 14:28:52 -0700 Subject: [PATCH 07/13] synchronized getWorkerBroadcasts() --- .../main/scala/org/apache/spark/api/python/PythonRDD.scala | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index ba9adce330db0..8eb749f47e4a1 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -303,7 +303,9 @@ private[spark] object PythonRDD extends Logging { // remember the broadcasts sent to each worker private val workerBroadcasts = new mutable.WeakHashMap[Socket, mutable.Set[Long]]() private def getWorkerBroadcasts(worker: Socket) = { - workerBroadcasts.getOrElseUpdate(worker, new mutable.HashSet[Long]()) + synchronized { + workerBroadcasts.getOrElseUpdate(worker, new mutable.HashSet[Long]()) + } } /** From ac3206ea884022733c15f1d46fe1296ba8e92f5f Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Mon, 8 Sep 2014 14:42:07 -0700 Subject: [PATCH 08/13] renaming --- .../org/apache/spark/api/python/PythonRDD.scala | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index 8eb749f47e4a1..62ecf2dede3d2 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -203,25 +203,25 @@ private[spark] class PythonRDD( PythonRDD.writeUTF(include, dataOut) } // Broadcast variables - val bids = PythonRDD.getWorkerBroadcasts(worker) - val nbids = broadcastVars.map(_.id).toSet + val oldBids = PythonRDD.getWorkerBroadcasts(worker) + val newBids = broadcastVars.map(_.id).toSet // number of different broadcasts - val cnt = bids.diff(nbids).size + 
nbids.diff(bids).size + val cnt = oldBids.diff(newBids).size + newBids.diff(oldBids).size dataOut.writeInt(cnt) - for (bid <- bids) { - if (!nbids.contains(bid)) { + for (bid <- oldBids) { + if (!newBids.contains(bid)) { // remove the broadcast from worker dataOut.writeLong(- bid - 1) // bid >= 0 - bids.remove(bid) + oldBids.remove(bid) } } for (broadcast <- broadcastVars) { - if (!bids.contains(broadcast.id)) { + if (!oldBids.contains(broadcast.id)) { // send new broadcast dataOut.writeLong(broadcast.id) dataOut.writeInt(broadcast.value.length) dataOut.write(broadcast.value) - bids.add(broadcast.id) + oldBids.add(broadcast.id) } } dataOut.flush() From 7abb22454175406c225f00950804e9fee5d56ed0 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Mon, 8 Sep 2014 15:11:11 -0700 Subject: [PATCH 09/13] refactor: sychronized with itself --- .../api/python/PythonWorkerFactory.scala | 65 +++++++++---------- 1 file changed, 32 insertions(+), 33 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala index 118380c369db9..5e73747f2e8b8 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala @@ -54,7 +54,7 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String def create(): Socket = { if (useDaemon) { - idleWorkers.synchronized { + synchronized { if (idleWorkers.size > 0) { return idleWorkers.dequeue() } @@ -216,18 +216,9 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String override def run() { while (true) { - idleWorkers.synchronized { + synchronized { if (lastActivity + IDLE_WORKER_TIMEOUT_MS < System.currentTimeMillis()) { - while (idleWorkers.length > 0) { - val worker = idleWorkers.dequeue() - try { - // the Python worker will exit after closing the socket - worker.close() - } catch { - case e: Exception => - logWarning("Failed to close worker socket", e) - } - } + cleanupIdleWorkers() lastActivity = System.currentTimeMillis() } } @@ -236,18 +227,24 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String } } + private def cleanupIdleWorkers() { + while (idleWorkers.length > 0) { + val worker = idleWorkers.dequeue() + try { + // the worker will exit after closing the socket + worker.close() + } catch { + case e: Exception => + logWarning("Failed to close worker socket", e) + } + } + } + private def stopDaemon() { synchronized { if (useDaemon) { - while (idleWorkers.length > 0) { - val worker = idleWorkers.dequeue() - try { - worker.close() - } catch { - case e: Exception => - logWarning("Failed to close worker socket", e) - } - } + cleanupIdleWorkers() + // Request shutdown of existing daemon by sending SIGTERM if (daemon != null) { daemon.destroy() @@ -266,25 +263,27 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String } def stopWorker(worker: Socket) { - if (useDaemon) { - if (daemon != null) { - daemonWorkers.get(worker).foreach { pid => - // tell daemon to kill worker by pid - val output = new DataOutputStream(daemon.getOutputStream) - output.writeInt(pid) - output.flush() - daemon.getOutputStream.flush() + synchronized { + if (useDaemon) { + if (daemon != null) { + daemonWorkers.get(worker).foreach { pid => + // tell daemon to kill worker by pid + val output = new DataOutputStream(daemon.getOutputStream) + output.writeInt(pid) + output.flush() + 
daemon.getOutputStream.flush() + } } + } else { + simpleWorkers.get(worker).foreach(_.destroy()) } - } else { - simpleWorkers.get(worker).foreach(_.destroy()) } worker.close() } def releaseWorker(worker: Socket) { if (useDaemon && envVars.get("SPARK_REUSE_WORKER").isDefined) { - idleWorkers.synchronized { + synchronized { lastActivity = System.currentTimeMillis() idleWorkers.enqueue(worker) } @@ -302,5 +301,5 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String private object PythonWorkerFactory { val PROCESS_WAIT_TIMEOUT_MS = 10000 - val IDLE_WORKER_TIMEOUT_MS = 60000 + val IDLE_WORKER_TIMEOUT_MS = 60000 // kill idle workers after 1 minute } From 760ab1f9b0282943569de97d86c68952d47ad53f Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Tue, 9 Sep 2014 17:39:35 -0700 Subject: [PATCH 10/13] do not reuse worker if there are any exceptions --- .../apache/spark/api/python/PythonRDD.scala | 6 ++++- .../api/python/PythonWorkerFactory.scala | 2 +- python/pyspark/tests.py | 24 +++++++++++++++++++ 3 files changed, 30 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index 62ecf2dede3d2..ca8eef5f99edf 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -73,9 +73,10 @@ private[spark] class PythonRDD( // Start a thread to feed the process input from our parent's iterator val writerThread = new WriterThread(env, worker, split, context) + var complete_cleanly = false context.addTaskCompletionListener { context => writerThread.shutdownOnTaskCompletion() - if (!context.isInterrupted) { + if (reuse_worker && complete_cleanly) { env.releasePythonWorker(pythonExec, envVars.toMap, worker) } else { try { @@ -141,6 +142,7 @@ private[spark] class PythonRDD( stream.readFully(update) accumulator += Collections.singletonList(update) } + complete_cleanly = true null } } catch { @@ -235,11 +237,13 @@ private[spark] class PythonRDD( } catch { case e: Exception if context.isCompleted || context.isInterrupted => logDebug("Exception thrown after task completion (likely due to cleanup)", e) + worker.shutdownOutput() case e: Exception => // We must avoid throwing exceptions here, because the thread uncaught exception handler // will kill the whole executor (see org.apache.spark.executor.Executor). 
_exception = e + worker.shutdownOutput() } } } diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala index 5e73747f2e8b8..71bdf0fe1b917 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala @@ -282,7 +282,7 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String } def releaseWorker(worker: Socket) { - if (useDaemon && envVars.get("SPARK_REUSE_WORKER").isDefined) { + if (useDaemon) { synchronized { lastActivity = System.currentTimeMillis() idleWorkers.enqueue(worker) diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index f1a75cbff5c19..a44592e5e1e34 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -1077,11 +1077,35 @@ def run(): except OSError: self.fail("daemon had been killed") + # run a normal job + rdd = self.sc.parallelize(range(100), 1) + self.assertEqual(100, rdd.map(str).count()) + def test_fd_leak(self): N = 1100 # fd limit is 1024 by default rdd = self.sc.parallelize(range(N), N) self.assertEquals(N, rdd.count()) + def test_after_exception(self): + def raise_exception(_): + raise Exception() + rdd = self.sc.parallelize(range(100), 1) + self.assertRaises(Exception, lambda: rdd.foreach(raise_exception)) + self.assertEqual(100, rdd.map(str).count()) + + def test_after_jvm_exception(self): + tempFile = tempfile.NamedTemporaryFile(delete=False) + tempFile.write("Hello World!") + tempFile.close() + data = self.sc.textFile(tempFile.name, 1) + filtered_data = data.filter(lambda x: True) + self.assertEqual(1, filtered_data.count()) + os.unlink(tempFile.name) + self.assertRaises(Exception, lambda: filtered_data.count()) + + rdd = self.sc.parallelize(range(100), 1) + self.assertEqual(100, rdd.map(str).count()) + class TestSparkSubmit(unittest.TestCase): From 3133a60e4488402bfad95fd8d13d6f2e7bd8e888 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Tue, 9 Sep 2014 17:53:33 -0700 Subject: [PATCH 11/13] fix accumulator with reused worker --- python/pyspark/tests.py | 11 +++++++++++ python/pyspark/worker.py | 1 + 2 files changed, 12 insertions(+) diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index a44592e5e1e34..76a398a952fb8 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -1106,6 +1106,17 @@ def test_after_jvm_exception(self): rdd = self.sc.parallelize(range(100), 1) self.assertEqual(100, rdd.map(str).count()) + def test_accumulator_when_reuse_worker(self): + from pyspark.accumulators import INT_ACCUMULATOR_PARAM + acc1 = self.sc.accumulator(0, INT_ACCUMULATOR_PARAM) + self.sc.parallelize(range(100), 20).foreach(lambda x: acc1.add(x)) + self.assertEqual(sum(range(100)), acc1.value) + + acc2 = self.sc.accumulator(0, INT_ACCUMULATOR_PARAM) + self.sc.parallelize(range(100), 20).foreach(lambda x: acc2.add(x)) + self.assertEqual(sum(range(100)), acc2.value) + self.assertEqual(sum(range(100)), acc1.value) + class TestSparkSubmit(unittest.TestCase): diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index 77254f599aa04..9acad54eb7e42 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -76,6 +76,7 @@ def main(infile, outfile): bid = - bid - 1 _broadcastRegistry.pop(bid, None) + _accumulatorRegistry.clear() command = pickleSer._read_with_length(infile) (func, deserializer, serializer) = command init_time = time.time() From 
cf1c55e4fe891742bd2102606b901cc834a87557 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Tue, 9 Sep 2014 17:54:35 -0700 Subject: [PATCH 12/13] address comments --- python/pyspark/worker.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index 9acad54eb7e42..61b8a74d060e8 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -74,7 +74,7 @@ def main(infile, outfile): _broadcastRegistry[bid] = Broadcast(bid, value) else: bid = - bid - 1 - _broadcastRegistry.pop(bid, None) + _broadcastRegistry.remove(bid) _accumulatorRegistry.clear() command = pickleSer._read_with_length(infile) From 3939f207e2abeef8cf1b13a65554935beed42788 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Tue, 9 Sep 2014 23:43:55 -0700 Subject: [PATCH 13/13] fix bug in serializer in mllib --- python/pyspark/mllib/_common.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/python/pyspark/mllib/_common.py b/python/pyspark/mllib/_common.py index bb60d3d0c8463..68f6033616726 100644 --- a/python/pyspark/mllib/_common.py +++ b/python/pyspark/mllib/_common.py @@ -21,7 +21,7 @@ from numpy import ndarray, float64, int64, int32, array_equal, array from pyspark import SparkContext, RDD from pyspark.mllib.linalg import SparseVector -from pyspark.serializers import Serializer +from pyspark.serializers import FramedSerializer """ @@ -451,18 +451,16 @@ def _serialize_rating(r): return ba -class RatingDeserializer(Serializer): +class RatingDeserializer(FramedSerializer): - def loads(self, stream): - length = struct.unpack("!i", stream.read(4))[0] - ba = stream.read(length) - res = ndarray(shape=(3, ), buffer=ba, dtype=float64, offset=4) + def loads(self, string): + res = ndarray(shape=(3, ), buffer=string, dtype=float64, offset=4) return int(res[0]), int(res[1]), res[2] def load_stream(self, stream): while True: try: - yield self.loads(stream) + yield self._read_with_length(stream) except struct.error: return except EOFError: