Commit ba54614

Author: Davies Liu (committed)
use socket to transfer data from JVM
1 parent 9517c8f commit ba54614

File tree (4 files changed: 53 additions, 58 deletions)

  core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
  python/pyspark/context.py
  python/pyspark/rdd.py
  python/pyspark/sql/dataframe.py
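At a glance, the commit replaces the old temp-file handoff (PythonRDD.writeToFile on the JVM side plus RDD._collect_iterator_through_file in Python) with a loopback socket: the JVM binds a ServerSocket on an ephemeral port, a daemon thread writes the serialized iterator to the single connection it accepts, and the port number is returned to Python through Py4J, where _load_from_socket connects to localhost and deserializes the stream. What follows is a rough, standard-library-only Python sketch of that handoff, for orientation only: serve_iterator and load_from_socket are toy stand-ins for the Scala serveIterator and pyspark's _load_from_socket, and pickle stands in for Spark's serializers; none of it is Spark code.

# Toy sketch of the socket handoff introduced by this commit (not Spark code).
import pickle
import socket
import threading


def serve_iterator(items):
    """Toy stand-in for PythonRDD.serveIterator: serve one connection, return the port."""
    server = socket.socket()
    server.bind(("localhost", 0))   # port 0: let the OS pick a free ephemeral port
    server.listen(1)                # a single client (the driver) is expected
    server.settimeout(3)            # give up if nobody connects, like setSoTimeout(3000)

    def run():
        try:
            conn, _ = server.accept()
            try:
                with conn.makefile("wb", 65536) as out:
                    for item in items:          # stands in for writeIteratorToStream
                        pickle.dump(item, out)
            finally:
                conn.close()                    # EOF signals end-of-stream to the reader
        except OSError as exc:                  # mirrors the NonFatal catch + logError
            print("error while serving iterator:", exc)
        finally:
            server.close()

    threading.Thread(target=run, daemon=True).start()
    return server.getsockname()[1]


def load_from_socket(port):
    """Toy stand-in for pyspark.rdd._load_from_socket."""
    sock = socket.socket()
    try:
        sock.connect(("localhost", port))
        rf = sock.makefile("rb", 65536)
        while True:
            try:
                yield pickle.load(rf)
            except EOFError:
                return
    finally:
        sock.close()


if __name__ == "__main__":
    port = serve_iterator(iter(range(5)))
    print(list(load_from_socket(port)))   # [0, 1, 2, 3, 4]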

core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala

Lines changed: 33 additions & 15 deletions
@@ -38,6 +38,8 @@ import org.apache.spark.input.PortableDataStream
 import org.apache.spark.rdd.RDD
 import org.apache.spark.util.Utils
 
+import scala.util.control.NonFatal
+
 private[spark] class PythonRDD(
     @transient parent: RDD[_],
     command: Array[Byte],
@@ -340,29 +342,28 @@ private[spark] object PythonRDD extends Logging {
   /**
    * Adapter for calling SparkContext#runJob from Python.
    *
-   * This method will return an iterator of an array that contains all elements in the RDD
+   * This method will serve an iterator of an array that contains all elements in the RDD
    * (effectively a collect()), but allows you to run on a certain subset of partitions,
    * or to enable local execution.
    */
   def runJob(
       sc: SparkContext,
       rdd: JavaRDD[Array[Byte]],
       partitions: JArrayList[Int],
-      allowLocal: Boolean): Iterator[Array[Byte]] = {
+      allowLocal: Boolean): Int = {
     type ByteArray = Array[Byte]
     type UnrolledPartition = Array[ByteArray]
     val allPartitions: Array[UnrolledPartition] =
       sc.runJob(rdd, (x: Iterator[ByteArray]) => x.toArray, partitions, allowLocal)
     val flattenedPartition: UnrolledPartition = Array.concat(allPartitions: _*)
-    flattenedPartition.iterator
+    serveIterator(flattenedPartition.iterator)
   }
 
   /**
-   * A helper function to collect an RDD as an iterator, then it only export the Iterator
-   * object to Py4j, easily be GCed.
+   * A helper function to collect an RDD as an iterator, then serve it via socket
    */
-  def collectAsIterator[T](jrdd: JavaRDD[T]): Iterator[T] = {
-    jrdd.collect().iterator()
+  def collectAndServe[T](rdd: RDD[T]): Int = {
+    serveIterator(rdd.collect().iterator)
   }
 
   def readRDDFromFile(sc: JavaSparkContext, filename: String, parallelism: Int):
@@ -582,15 +583,32 @@ private[spark] object PythonRDD extends Logging {
     dataOut.write(bytes)
   }
 
-  def writeToFile[T](items: java.util.Iterator[T], filename: String) {
-    import scala.collection.JavaConverters._
-    writeToFile(items.asScala, filename)
-  }
+  private def serveIterator[T](items: Iterator[T]): Int = {
+    val serverSocket = new ServerSocket(0, 1)
+    serverSocket.setReuseAddress(true)
+    serverSocket.setSoTimeout(3000)
+
+    new Thread("serve iterator") {
+      setDaemon(true)
+      override def run() {
+        try {
+          val sock = serverSocket.accept()
+          val out = new DataOutputStream(new BufferedOutputStream(sock.getOutputStream))
+          try {
+            writeIteratorToStream(items, out)
+          } finally {
+            out.close()
+          }
+        } catch {
+          case NonFatal(e) =>
+            logError(s"Error while sending iterator: $e")
+        } finally {
+          serverSocket.close()
+        }
+      }
+    }.start()
 
-  def writeToFile[T](items: Iterator[T], filename: String) {
-    val file = new DataOutputStream(new FileOutputStream(filename))
-    writeIteratorToStream(items, file)
-    file.close()
+    serverSocket.getLocalPort
   }
 
   private def getMergedConf(confAsMap: java.util.HashMap[String, String],
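For reference, the new contract on the Scala side: runJob and collectAndServe now return an Int port number, and serveIterator binds a ServerSocket with backlog 1 on an ephemeral port, writes the iterator to the first (and only) connection from a daemon thread, and gives up after a 3 second accept timeout. Unrolled from the Python side, the call sites added below in rdd.py amount to the following; illustrative only, assuming a local Spark install of this vintage.

from pyspark import SparkContext
from pyspark.rdd import _load_from_socket

sc = SparkContext("local", "serve-demo")
rdd = sc.parallelize([1, 2, 3])
# The JVM binds an ephemeral port, starts serving, and hands the port back via Py4J.
port = sc._jvm.PythonRDD.collectAndServe(rdd._jrdd.rdd())
# The client must connect within the 3 second accept window; collect() does this
# immediately by materializing the generator into a list.
print(list(_load_from_socket(port, rdd._jrdd_deserializer)))   # [1, 2, 3]
sc.stop()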

python/pyspark/context.py

Lines changed: 3 additions & 16 deletions
@@ -21,9 +21,7 @@
 from threading import Lock
 from tempfile import NamedTemporaryFile
 
-from py4j.java_gateway import JavaObject
 from py4j.java_collections import ListConverter
-import py4j.protocol
 
 from pyspark import accumulators
 from pyspark.accumulators import Accumulator
@@ -34,7 +32,7 @@
 from pyspark.serializers import PickleSerializer, BatchedSerializer, UTF8Deserializer, \
     PairDeserializer, AutoBatchedSerializer, NoOpSerializer
 from pyspark.storagelevel import StorageLevel
-from pyspark.rdd import RDD
+from pyspark.rdd import RDD, _load_from_socket
 from pyspark.traceback_utils import CallSite, first_spark_call
 from pyspark.status import StatusTracker
 from pyspark.profiler import ProfilerCollector, BasicProfiler
@@ -51,15 +49,6 @@
 }
 
 
-# The implementation in Py4j will create 'Java' member for parameter (JavaObject)
-# because of circular reference between JavaObject and JavaMember, then the object
-# can not be released after used until GC kick-in.
-def is_python_proxy(parameter):
-    return not isinstance(parameter, JavaObject) and _old_is_python_proxy(parameter)
-_old_is_python_proxy = py4j.protocol.is_python_proxy
-py4j.protocol.is_python_proxy = is_python_proxy
-
-
 class SparkContext(object):
 
     """
@@ -70,7 +59,6 @@ class SparkContext(object):
 
     _gateway = None
     _jvm = None
-    _writeToFile = None
     _next_accum_id = 0
     _active_spark_context = None
     _lock = Lock()
@@ -232,7 +220,6 @@ def _ensure_initialized(cls, instance=None, gateway=None):
         if not SparkContext._gateway:
             SparkContext._gateway = gateway or launch_gateway()
             SparkContext._jvm = SparkContext._gateway.jvm
-            SparkContext._writeToFile = SparkContext._jvm.PythonRDD.writeToFile
 
         if instance:
             if (SparkContext._active_spark_context and
@@ -851,8 +838,8 @@ def runJob(self, rdd, partitionFunc, partitions=None, allowLocal=False):
         # by runJob() in order to avoid having to pass a Python lambda into
         # SparkContext#runJob.
         mappedRDD = rdd.mapPartitions(partitionFunc)
-        it = self._jvm.PythonRDD.runJob(self._jsc.sc(), mappedRDD._jrdd, javaPartitions, allowLocal)
-        return list(mappedRDD._collect_iterator_through_file(it))
+        port = self._jvm.PythonRDD.runJob(self._jsc.sc(), mappedRDD._jrdd, javaPartitions, allowLocal)
+        return list(_load_from_socket(port, mappedRDD._jrdd_deserializer))
 
     def show_profiles(self):
         """ Print the profile stats to stdout """

python/pyspark/rdd.py

Lines changed: 14 additions & 16 deletions
@@ -19,7 +19,6 @@
 from collections import defaultdict
 from itertools import chain, ifilter, imap
 import operator
-import os
 import sys
 import shlex
 from subprocess import Popen, PIPE
@@ -29,6 +28,7 @@
 import heapq
 import bisect
 import random
+import socket
 from math import sqrt, log, isinf, isnan, pow, ceil
 
 from pyspark.serializers import NoOpSerializer, CartesianDeserializer, \
@@ -111,6 +111,17 @@ def _parse_memory(s):
     return int(float(s[:-1]) * units[s[-1].lower()])
 
 
+def _load_from_socket(port, serializer):
+    sock = socket.socket()
+    try:
+        sock.connect(("localhost", port))
+        rf = sock.makefile("rb", 65536)
+        for item in serializer.load_stream(rf):
+            yield item
+    finally:
+        sock.close()
+
+
 class Partitioner(object):
     def __init__(self, numPartitions, partitionFunc):
         self.numPartitions = numPartitions
@@ -698,21 +709,8 @@ def collect(self):
         Return a list that contains all of the elements in this RDD.
         """
         with SCCallSiteSync(self.context) as css:
-            bytesInJava = self.ctx._jvm.PythonRDD.collectAsIterator(self._jrdd)
-            return list(self._collect_iterator_through_file(bytesInJava))
-
-    def _collect_iterator_through_file(self, iterator):
-        # Transferring lots of data through Py4J can be slow because
-        # socket.readline() is inefficient. Instead, we'll dump the data to a
-        # file and read it back.
-        tempFile = NamedTemporaryFile(delete=False, dir=self.ctx._temp_dir)
-        tempFile.close()
-        self.ctx._writeToFile(iterator, tempFile.name)
-        # Read the data into Python and deserialize it:
-        with open(tempFile.name, 'rb') as tempFile:
-            for item in self._jrdd_deserializer.load_stream(tempFile):
-                yield item
-        os.unlink(tempFile.name)
+            port = self.ctx._jvm.PythonRDD.collectAndServe(self._jrdd.rdd())
+        return list(_load_from_socket(port, self._jrdd_deserializer))
 
     def reduce(self, f):
         """

python/pyspark/sql/dataframe.py

Lines changed: 3 additions & 11 deletions
@@ -19,13 +19,11 @@
 import itertools
 import warnings
 import random
-import os
-from tempfile import NamedTemporaryFile
 
 from py4j.java_collections import ListConverter, MapConverter
 
 from pyspark.context import SparkContext
-from pyspark.rdd import RDD
+from pyspark.rdd import RDD, _load_from_socket
 from pyspark.serializers import BatchedSerializer, PickleSerializer, UTF8Deserializer
 from pyspark.storagelevel import StorageLevel
 from pyspark.traceback_utils import SCCallSiteSync
@@ -310,14 +308,8 @@ def collect(self):
         [Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')]
         """
         with SCCallSiteSync(self._sc) as css:
-            bytesInJava = self._sc._jvm.PythonRDD.collectAsIterator(self._jdf.javaToPython())
-            tempFile = NamedTemporaryFile(delete=False, dir=self._sc._temp_dir)
-            tempFile.close()
-            self._sc._writeToFile(bytesInJava, tempFile.name)
-            # Read the data into Python and deserialize it:
-            with open(tempFile.name, 'rb') as tempFile:
-                rs = list(BatchedSerializer(PickleSerializer()).load_stream(tempFile))
-            os.unlink(tempFile.name)
+            port = self._sc._jvm.PythonRDD.collectAndServe(self._jdf.javaToPython().rdd())
+            rs = list(_load_from_socket(port, BatchedSerializer(PickleSerializer())))
         cls = _create_cls(self.schema)
         return [cls(r) for r in rs]
 
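DataFrame.collect() follows the same route: the JVM serves the pickled rows through collectAndServe and the Python side reads them with BatchedSerializer(PickleSerializer()). An illustrative 1.3-era usage, reusing the rows from the doctest above and assuming an existing SparkContext named sc:

from pyspark.sql import SQLContext, Row

sqlContext = SQLContext(sc)
df = sqlContext.createDataFrame([Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')])
# Rows are streamed from the JVM over the loopback socket and rebuilt as Row objects.
print(df.collect())
# [Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')]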
