From 9517c8f50231ed7bfdc7e4412bd0d7e5715cb600 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Thu, 5 Mar 2015 18:11:43 -0800 Subject: [PATCH 1/4] fix memory leak in collect() --- .../apache/spark/api/python/PythonRDD.scala | 19 +++++++++++++------ python/pyspark/context.py | 15 +++++++++++++-- python/pyspark/rdd.py | 2 +- python/pyspark/sql/dataframe.py | 2 +- 4 files changed, 28 insertions(+), 10 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index b1cec0f6472b0..07b8308c3c1a2 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -19,23 +19,22 @@ package org.apache.spark.api.python import java.io._ import java.net._ -import java.util.{List => JList, ArrayList => JArrayList, Map => JMap, UUID, Collections} - -import org.apache.spark.input.PortableDataStream +import java.util.{Collections, ArrayList => JArrayList, List => JList, Map => JMap} import scala.collection.JavaConversions._ import scala.collection.mutable import scala.language.existentials import com.google.common.base.Charsets.UTF_8 - import org.apache.hadoop.conf.Configuration import org.apache.hadoop.io.compress.CompressionCodec -import org.apache.hadoop.mapred.{InputFormat, OutputFormat, JobConf} +import org.apache.hadoop.mapred.{InputFormat, JobConf, OutputFormat} import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat, OutputFormat => NewOutputFormat} + import org.apache.spark._ -import org.apache.spark.api.java.{JavaSparkContext, JavaPairRDD, JavaRDD} +import org.apache.spark.api.java.{JavaPairRDD, JavaRDD, JavaSparkContext} import org.apache.spark.broadcast.Broadcast +import org.apache.spark.input.PortableDataStream import org.apache.spark.rdd.RDD import org.apache.spark.util.Utils @@ -358,6 +357,14 @@ private[spark] object PythonRDD extends Logging { flattenedPartition.iterator } + /** + * A helper function to collect an RDD as an iterator, then it only export the Iterator + * object to Py4j, easily be GCed. + */ + def collectAsIterator[T](jrdd: JavaRDD[T]): Iterator[T] = { + jrdd.collect().iterator() + } + def readRDDFromFile(sc: JavaSparkContext, filename: String, parallelism: Int): JavaRDD[Array[Byte]] = { val file = new DataInputStream(new FileInputStream(filename)) diff --git a/python/pyspark/context.py b/python/pyspark/context.py index 6011caf9f1c5a..af045b83ec51a 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -21,6 +21,10 @@ from threading import Lock from tempfile import NamedTemporaryFile +from py4j.java_gateway import JavaObject +from py4j.java_collections import ListConverter +import py4j.protocol + from pyspark import accumulators from pyspark.accumulators import Accumulator from pyspark.broadcast import Broadcast @@ -35,8 +39,6 @@ from pyspark.status import StatusTracker from pyspark.profiler import ProfilerCollector, BasicProfiler -from py4j.java_collections import ListConverter - __all__ = ['SparkContext'] @@ -49,6 +51,15 @@ } +# The implementation in Py4j will create 'Java' member for parameter (JavaObject) +# because of circular reference between JavaObject and JavaMember, then the object +# can not be released after used until GC kick-in. 
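A minimal sketch of the cycle described in the comment above, using simplified stand-ins for py4j's JavaObject and JavaMember (not py4j's actual classes): once the two objects point at each other, reference counting alone can no longer free them, and only a cyclic-GC pass reclaims the pair.

    import gc

    class JavaMember(object):
        def __init__(self, target):
            self.target = target       # the member keeps a reference back to the object

    class JavaObject(object):
        def __init__(self):
            # the object caches the member, closing the cycle
            self.java = JavaMember(self)

    obj = JavaObject()
    del obj                            # refcounting cannot release the cycle ...
    print(gc.collect() > 0)            # ... only a cyclic GC pass does -> True

This is why the workaround below (removed again in the next patch in favour of the socket approach) avoids treating JavaObject parameters as Python proxies at all.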
+def is_python_proxy(parameter): + return not isinstance(parameter, JavaObject) and _old_is_python_proxy(parameter) +_old_is_python_proxy = py4j.protocol.is_python_proxy +py4j.protocol.is_python_proxy = is_python_proxy + + class SparkContext(object): """ diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index cb12fed98c53d..12541e711120d 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -698,7 +698,7 @@ def collect(self): Return a list that contains all of the elements in this RDD. """ with SCCallSiteSync(self.context) as css: - bytesInJava = self._jrdd.collect().iterator() + bytesInJava = self.ctx._jvm.PythonRDD.collectAsIterator(self._jrdd) return list(self._collect_iterator_through_file(bytesInJava)) def _collect_iterator_through_file(self, iterator): diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 5c3b7377c33b5..de416f4c6978c 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -310,7 +310,7 @@ def collect(self): [Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')] """ with SCCallSiteSync(self._sc) as css: - bytesInJava = self._jdf.javaToPython().collect().iterator() + bytesInJava = self._sc._jvm.PythonRDD.collectAsIterator(self._jdf.javaToPython()) tempFile = NamedTemporaryFile(delete=False, dir=self._sc._temp_dir) tempFile.close() self._sc._writeToFile(bytesInJava, tempFile.name) From ba5461492a931eeb3cf945edd7cf63f4a8f99b0f Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Thu, 5 Mar 2015 23:04:55 -0800 Subject: [PATCH 2/4] use socket to transfer data from JVM --- .../apache/spark/api/python/PythonRDD.scala | 48 +++++++++++++------ python/pyspark/context.py | 19 ++------ python/pyspark/rdd.py | 30 ++++++------ python/pyspark/sql/dataframe.py | 14 ++---- 4 files changed, 53 insertions(+), 58 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index 07b8308c3c1a2..56b3d20504b7a 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -38,6 +38,8 @@ import org.apache.spark.input.PortableDataStream import org.apache.spark.rdd.RDD import org.apache.spark.util.Utils +import scala.util.control.NonFatal + private[spark] class PythonRDD( @transient parent: RDD[_], command: Array[Byte], @@ -340,7 +342,7 @@ private[spark] object PythonRDD extends Logging { /** * Adapter for calling SparkContext#runJob from Python. * - * This method will return an iterator of an array that contains all elements in the RDD + * This method will serve an iterator of an array that contains all elements in the RDD * (effectively a collect()), but allows you to run on a certain subset of partitions, * or to enable local execution. */ @@ -348,21 +350,20 @@ private[spark] object PythonRDD extends Logging { sc: SparkContext, rdd: JavaRDD[Array[Byte]], partitions: JArrayList[Int], - allowLocal: Boolean): Iterator[Array[Byte]] = { + allowLocal: Boolean): Int = { type ByteArray = Array[Byte] type UnrolledPartition = Array[ByteArray] val allPartitions: Array[UnrolledPartition] = sc.runJob(rdd, (x: Iterator[ByteArray]) => x.toArray, partitions, allowLocal) val flattenedPartition: UnrolledPartition = Array.concat(allPartitions: _*) - flattenedPartition.iterator + serveIterator(flattenedPartition.iterator) } /** - * A helper function to collect an RDD as an iterator, then it only export the Iterator - * object to Py4j, easily be GCed. 
+ * A helper function to collect an RDD as an iterator, then serve it via socket */ - def collectAsIterator[T](jrdd: JavaRDD[T]): Iterator[T] = { - jrdd.collect().iterator() + def collectAndServe[T](rdd: RDD[T]): Int = { + serveIterator(rdd.collect().iterator) } def readRDDFromFile(sc: JavaSparkContext, filename: String, parallelism: Int): @@ -582,15 +583,32 @@ private[spark] object PythonRDD extends Logging { dataOut.write(bytes) } - def writeToFile[T](items: java.util.Iterator[T], filename: String) { - import scala.collection.JavaConverters._ - writeToFile(items.asScala, filename) - } + private def serveIterator[T](items: Iterator[T]): Int = { + val serverSocket = new ServerSocket(0, 1) + serverSocket.setReuseAddress(true) + serverSocket.setSoTimeout(3000) + + new Thread("serve iterator") { + setDaemon(true) + override def run() { + try { + val sock = serverSocket.accept() + val out = new DataOutputStream(new BufferedOutputStream(sock.getOutputStream)) + try { + writeIteratorToStream(items, out) + } finally { + out.close() + } + } catch { + case NonFatal(e) => + logError(s"Error while sending iterator: $e") + } finally { + serverSocket.close() + } + } + }.start() - def writeToFile[T](items: Iterator[T], filename: String) { - val file = new DataOutputStream(new FileOutputStream(filename)) - writeIteratorToStream(items, file) - file.close() + serverSocket.getLocalPort } private def getMergedConf(confAsMap: java.util.HashMap[String, String], diff --git a/python/pyspark/context.py b/python/pyspark/context.py index af045b83ec51a..50b69b6ac5459 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -21,9 +21,7 @@ from threading import Lock from tempfile import NamedTemporaryFile -from py4j.java_gateway import JavaObject from py4j.java_collections import ListConverter -import py4j.protocol from pyspark import accumulators from pyspark.accumulators import Accumulator @@ -34,7 +32,7 @@ from pyspark.serializers import PickleSerializer, BatchedSerializer, UTF8Deserializer, \ PairDeserializer, AutoBatchedSerializer, NoOpSerializer from pyspark.storagelevel import StorageLevel -from pyspark.rdd import RDD +from pyspark.rdd import RDD, _load_from_socket from pyspark.traceback_utils import CallSite, first_spark_call from pyspark.status import StatusTracker from pyspark.profiler import ProfilerCollector, BasicProfiler @@ -51,15 +49,6 @@ } -# The implementation in Py4j will create 'Java' member for parameter (JavaObject) -# because of circular reference between JavaObject and JavaMember, then the object -# can not be released after used until GC kick-in. -def is_python_proxy(parameter): - return not isinstance(parameter, JavaObject) and _old_is_python_proxy(parameter) -_old_is_python_proxy = py4j.protocol.is_python_proxy -py4j.protocol.is_python_proxy = is_python_proxy - - class SparkContext(object): """ @@ -70,7 +59,6 @@ class SparkContext(object): _gateway = None _jvm = None - _writeToFile = None _next_accum_id = 0 _active_spark_context = None _lock = Lock() @@ -232,7 +220,6 @@ def _ensure_initialized(cls, instance=None, gateway=None): if not SparkContext._gateway: SparkContext._gateway = gateway or launch_gateway() SparkContext._jvm = SparkContext._gateway.jvm - SparkContext._writeToFile = SparkContext._jvm.PythonRDD.writeToFile if instance: if (SparkContext._active_spark_context and @@ -851,8 +838,8 @@ def runJob(self, rdd, partitionFunc, partitions=None, allowLocal=False): # by runJob() in order to avoid having to pass a Python lambda into # SparkContext#runJob. 
mappedRDD = rdd.mapPartitions(partitionFunc) - it = self._jvm.PythonRDD.runJob(self._jsc.sc(), mappedRDD._jrdd, javaPartitions, allowLocal) - return list(mappedRDD._collect_iterator_through_file(it)) + port = self._jvm.PythonRDD.runJob(self._jsc.sc(), mappedRDD._jrdd, javaPartitions, allowLocal) + return list(_load_from_socket(port, mappedRDD._jrdd_deserializer)) def show_profiles(self): """ Print the profile stats to stdout """ diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 12541e711120d..bf17f513c0bc3 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -19,7 +19,6 @@ from collections import defaultdict from itertools import chain, ifilter, imap import operator -import os import sys import shlex from subprocess import Popen, PIPE @@ -29,6 +28,7 @@ import heapq import bisect import random +import socket from math import sqrt, log, isinf, isnan, pow, ceil from pyspark.serializers import NoOpSerializer, CartesianDeserializer, \ @@ -111,6 +111,17 @@ def _parse_memory(s): return int(float(s[:-1]) * units[s[-1].lower()]) +def _load_from_socket(port, serializer): + sock = socket.socket() + try: + sock.connect(("localhost", port)) + rf = sock.makefile("rb", 65536) + for item in serializer.load_stream(rf): + yield item + finally: + sock.close() + + class Partitioner(object): def __init__(self, numPartitions, partitionFunc): self.numPartitions = numPartitions @@ -698,21 +709,8 @@ def collect(self): Return a list that contains all of the elements in this RDD. """ with SCCallSiteSync(self.context) as css: - bytesInJava = self.ctx._jvm.PythonRDD.collectAsIterator(self._jrdd) - return list(self._collect_iterator_through_file(bytesInJava)) - - def _collect_iterator_through_file(self, iterator): - # Transferring lots of data through Py4J can be slow because - # socket.readline() is inefficient. Instead, we'll dump the data to a - # file and read it back. 
- tempFile = NamedTemporaryFile(delete=False, dir=self.ctx._temp_dir) - tempFile.close() - self.ctx._writeToFile(iterator, tempFile.name) - # Read the data into Python and deserialize it: - with open(tempFile.name, 'rb') as tempFile: - for item in self._jrdd_deserializer.load_stream(tempFile): - yield item - os.unlink(tempFile.name) + port = self.ctx._jvm.PythonRDD.collectAndServe(self._jrdd.rdd()) + return list(_load_from_socket(port, self._jrdd_deserializer)) def reduce(self, f): """ diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index de416f4c6978c..e8ce4547455a5 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -19,13 +19,11 @@ import itertools import warnings import random -import os -from tempfile import NamedTemporaryFile from py4j.java_collections import ListConverter, MapConverter from pyspark.context import SparkContext -from pyspark.rdd import RDD +from pyspark.rdd import RDD, _load_from_socket from pyspark.serializers import BatchedSerializer, PickleSerializer, UTF8Deserializer from pyspark.storagelevel import StorageLevel from pyspark.traceback_utils import SCCallSiteSync @@ -310,14 +308,8 @@ def collect(self): [Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')] """ with SCCallSiteSync(self._sc) as css: - bytesInJava = self._sc._jvm.PythonRDD.collectAsIterator(self._jdf.javaToPython()) - tempFile = NamedTemporaryFile(delete=False, dir=self._sc._temp_dir) - tempFile.close() - self._sc._writeToFile(bytesInJava, tempFile.name) - # Read the data into Python and deserialize it: - with open(tempFile.name, 'rb') as tempFile: - rs = list(BatchedSerializer(PickleSerializer()).load_stream(tempFile)) - os.unlink(tempFile.name) + port = self._sc._jvm.PythonRDD.collectAndServe(self._jdf.javaToPython().rdd()) + rs = list(_load_from_socket(port, BatchedSerializer(PickleSerializer()))) cls = _create_cls(self.schema) return [cls(r) for r in rs] From 24c92a41354c936a864a32c4dbae3fbfeeb95ce3 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Thu, 5 Mar 2015 23:15:43 -0800 Subject: [PATCH 3/4] fix style --- python/pyspark/context.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/pyspark/context.py b/python/pyspark/context.py index 50b69b6ac5459..78dccc40470e3 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -838,7 +838,8 @@ def runJob(self, rdd, partitionFunc, partitions=None, allowLocal=False): # by runJob() in order to avoid having to pass a Python lambda into # SparkContext#runJob. 
mappedRDD = rdd.mapPartitions(partitionFunc) - port = self._jvm.PythonRDD.runJob(self._jsc.sc(), mappedRDD._jrdd, javaPartitions, allowLocal) + port = self._jvm.PythonRDD.runJob(self._jsc.sc(), mappedRDD._jrdd, javaPartitions, + allowLocal) return list(_load_from_socket(port, mappedRDD._jrdd_deserializer)) def show_profiles(self): From d7302864aa6eaed5fa052ccefcba19db00091a40 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Mon, 9 Mar 2015 14:35:02 -0700 Subject: [PATCH 4/4] address comments --- .../apache/spark/api/python/PythonRDD.scala | 29 +++++++++++++++---- 1 file changed, 23 insertions(+), 6 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index 56b3d20504b7a..8d4a53b4ca9b0 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -345,6 +345,8 @@ private[spark] object PythonRDD extends Logging { * This method will serve an iterator of an array that contains all elements in the RDD * (effectively a collect()), but allows you to run on a certain subset of partitions, * or to enable local execution. + * + * @return the port number of a local socket which serves the data collected from this job. */ def runJob( sc: SparkContext, @@ -356,14 +358,17 @@ private[spark] object PythonRDD extends Logging { val allPartitions: Array[UnrolledPartition] = sc.runJob(rdd, (x: Iterator[ByteArray]) => x.toArray, partitions, allowLocal) val flattenedPartition: UnrolledPartition = Array.concat(allPartitions: _*) - serveIterator(flattenedPartition.iterator) + serveIterator(flattenedPartition.iterator, + s"serve RDD ${rdd.id} with partitions ${partitions.mkString(",")}") } /** - * A helper function to collect an RDD as an iterator, then serve it via socket + * A helper function to collect an RDD as an iterator, then serve it via socket. + * + * @return the port number of a local socket which serves the data collected from this job. */ def collectAndServe[T](rdd: RDD[T]): Int = { - serveIterator(rdd.collect().iterator) + serveIterator(rdd.collect().iterator, s"serve RDD ${rdd.id}") } def readRDDFromFile(sc: JavaSparkContext, filename: String, parallelism: Int): @@ -583,12 +588,24 @@ private[spark] object PythonRDD extends Logging { dataOut.write(bytes) } - private def serveIterator[T](items: Iterator[T]): Int = { + /** + * Create a socket server and a background thread to serve the data in `items`, + * + * The socket server can only accept one connection, or close if no connection + * in 3 seconds. + * + * Once a connection comes in, it tries to serialize all the data in `items` + * and send them into this connection. + * + * The thread will terminate after all the data are sent or any exceptions happen. + */ + private def serveIterator[T](items: Iterator[T], threadName: String): Int = { val serverSocket = new ServerSocket(0, 1) serverSocket.setReuseAddress(true) + // Close the socket if no connection in 3 seconds serverSocket.setSoTimeout(3000) - new Thread("serve iterator") { + new Thread(threadName) { setDaemon(true) override def run() { try { @@ -601,7 +618,7 @@ private[spark] object PythonRDD extends Logging { } } catch { case NonFatal(e) => - logError(s"Error while sending iterator: $e") + logError(s"Error while sending iterator", e) } finally { serverSocket.close() }
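The mechanism the last patches converge on, serving the collected iterator over an ephemeral localhost socket and handing only the port number back through Py4J, can be sketched in plain Python. The snippet below mirrors the shape of serveIterator and _load_from_socket but uses pickle instead of Spark's serializers, and the names serve_iterator/load_from_socket are illustration-only; it is a sketch of the pattern, not PySpark's code.

    import pickle
    import socket
    import threading

    def serve_iterator(items):
        """Bind an ephemeral localhost port, serve `items` to one client, then exit."""
        server = socket.socket()
        server.bind(("localhost", 0))        # port 0: the OS picks a free port
        server.listen(1)                     # accept a single connection
        server.settimeout(3)                 # give up if no client connects in 3 seconds

        def run():
            try:
                conn, _ = server.accept()
            except socket.timeout:
                server.close()               # nobody connected; just clean up
                return
            try:
                out = conn.makefile("wb", 65536)
                try:
                    for item in items:       # stream the items into the connection
                        pickle.dump(item, out)
                finally:
                    out.close()              # flush the buffer before closing the socket
            finally:
                conn.close()
                server.close()

        t = threading.Thread(target=run)
        t.daemon = True                      # do not keep the process alive for this thread
        t.start()
        return server.getsockname()[1]       # only this int crosses the Py4J boundary

    def load_from_socket(port):
        """Connect to the serving port and yield items as they are deserialized."""
        sock = socket.socket()
        try:
            sock.connect(("localhost", port))
            rf = sock.makefile("rb", 65536)
            while True:
                try:
                    yield pickle.load(rf)
                except EOFError:
                    return
        finally:
            sock.close()

    port = serve_iterator(iter(range(5)))
    print(list(load_from_socket(port)))      # [0, 1, 2, 3, 4]

Compared with the first patch, nothing larger than the port number is held on the Py4J side and no temporary file is written; the daemon thread terminates as soon as the data has been sent or the 3-second accept timeout fires, which matches the behaviour described in the serveIterator scaladoc above.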