Skip to content

Commit 9517c8f

Browse files
author
Davies Liu
committed
fix memory leak in collect()
1 parent eb48fd6 commit 9517c8f

File tree

4 files changed

+28
-10
lines changed

4 files changed

+28
-10
lines changed

core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,23 +19,22 @@ package org.apache.spark.api.python
1919

2020
import java.io._
2121
import java.net._
22-
import java.util.{List => JList, ArrayList => JArrayList, Map => JMap, UUID, Collections}
23-
24-
import org.apache.spark.input.PortableDataStream
22+
import java.util.{Collections, ArrayList => JArrayList, List => JList, Map => JMap}
2523

2624
import scala.collection.JavaConversions._
2725
import scala.collection.mutable
2826
import scala.language.existentials
2927

3028
import com.google.common.base.Charsets.UTF_8
31-
3229
import org.apache.hadoop.conf.Configuration
3330
import org.apache.hadoop.io.compress.CompressionCodec
34-
import org.apache.hadoop.mapred.{InputFormat, OutputFormat, JobConf}
31+
import org.apache.hadoop.mapred.{InputFormat, JobConf, OutputFormat}
3532
import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat, OutputFormat => NewOutputFormat}
33+
3634
import org.apache.spark._
37-
import org.apache.spark.api.java.{JavaSparkContext, JavaPairRDD, JavaRDD}
35+
import org.apache.spark.api.java.{JavaPairRDD, JavaRDD, JavaSparkContext}
3836
import org.apache.spark.broadcast.Broadcast
37+
import org.apache.spark.input.PortableDataStream
3938
import org.apache.spark.rdd.RDD
4039
import org.apache.spark.util.Utils
4140

@@ -358,6 +357,14 @@ private[spark] object PythonRDD extends Logging {
358357
flattenedPartition.iterator
359358
}
360359

360+
/**
361+
* A helper function to collect an RDD as an iterator, so that only the Iterator
362+
* object is exported to Py4j and can easily be GCed.
363+
*/
364+
def collectAsIterator[T](jrdd: JavaRDD[T]): Iterator[T] = {
365+
jrdd.collect().iterator()
366+
}
367+
361368
def readRDDFromFile(sc: JavaSparkContext, filename: String, parallelism: Int):
362369
JavaRDD[Array[Byte]] = {
363370
val file = new DataInputStream(new FileInputStream(filename))

python/pyspark/context.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,10 @@
2121
from threading import Lock
2222
from tempfile import NamedTemporaryFile
2323

24+
from py4j.java_gateway import JavaObject
25+
from py4j.java_collections import ListConverter
26+
import py4j.protocol
27+
2428
from pyspark import accumulators
2529
from pyspark.accumulators import Accumulator
2630
from pyspark.broadcast import Broadcast
@@ -35,8 +39,6 @@
3539
from pyspark.status import StatusTracker
3640
from pyspark.profiler import ProfilerCollector, BasicProfiler
3741

38-
from py4j.java_collections import ListConverter
39-
4042

4143
__all__ = ['SparkContext']
4244

@@ -49,6 +51,15 @@
4951
}
5052

5153

54+
# The implementation in Py4j will create a 'Java' member for a parameter (JavaObject)
55+
# because of a circular reference between JavaObject and JavaMember, so the object
56+
# cannot be released after use until GC kicks in.
57+
def is_python_proxy(parameter):
58+
return not isinstance(parameter, JavaObject) and _old_is_python_proxy(parameter)
59+
_old_is_python_proxy = py4j.protocol.is_python_proxy
60+
py4j.protocol.is_python_proxy = is_python_proxy
61+
62+
5263
class SparkContext(object):
5364

5465
"""

python/pyspark/rdd.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -698,7 +698,7 @@ def collect(self):
698698
Return a list that contains all of the elements in this RDD.
699699
"""
700700
with SCCallSiteSync(self.context) as css:
701-
bytesInJava = self._jrdd.collect().iterator()
701+
bytesInJava = self.ctx._jvm.PythonRDD.collectAsIterator(self._jrdd)
702702
return list(self._collect_iterator_through_file(bytesInJava))
703703

704704
def _collect_iterator_through_file(self, iterator):

python/pyspark/sql/dataframe.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -310,7 +310,7 @@ def collect(self):
310310
[Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')]
311311
"""
312312
with SCCallSiteSync(self._sc) as css:
313-
bytesInJava = self._jdf.javaToPython().collect().iterator()
313+
bytesInJava = self._sc._jvm.PythonRDD.collectAsIterator(self._jdf.javaToPython())
314314
tempFile = NamedTemporaryFile(delete=False, dir=self._sc._temp_dir)
315315
tempFile.close()
316316
self._sc._writeToFile(bytesInJava, tempFile.name)

0 commit comments

Comments
 (0)