@@ -19,26 +19,27 @@ package org.apache.spark.api.python
 
 import java.io._
 import java.net._
-import java.util.{List => JList, ArrayList => JArrayList, Map => JMap, UUID, Collections}
-
-import org.apache.spark.input.PortableDataStream
+import java.util.{Collections, ArrayList => JArrayList, List => JList, Map => JMap}
 
 import scala.collection.JavaConversions._
 import scala.collection.mutable
 import scala.language.existentials
 
 import com.google.common.base.Charsets.UTF_8
-
 import org.apache.hadoop.conf.Configuration
 import org.apache.hadoop.io.compress.CompressionCodec
-import org.apache.hadoop.mapred.{InputFormat, OutputFormat, JobConf}
+import org.apache.hadoop.mapred.{InputFormat, JobConf, OutputFormat}
 import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat, OutputFormat => NewOutputFormat}
+
 import org.apache.spark._
-import org.apache.spark.api.java.{JavaSparkContext, JavaPairRDD, JavaRDD}
+import org.apache.spark.api.java.{JavaPairRDD, JavaRDD, JavaSparkContext}
 import org.apache.spark.broadcast.Broadcast
+import org.apache.spark.input.PortableDataStream
 import org.apache.spark.rdd.RDD
 import org.apache.spark.util.Utils
 
+import scala.util.control.NonFatal
+
 private[spark] class PythonRDD(
     @transient parent: RDD[_],
     command: Array[Byte],
@@ -341,21 +342,33 @@ private[spark] object PythonRDD extends Logging {
   /**
    * Adapter for calling SparkContext#runJob from Python.
    *
-   * This method will return an iterator of an array that contains all elements in the RDD
+   * This method will serve an iterator of an array that contains all elements in the RDD
    * (effectively a collect()), but allows you to run on a certain subset of partitions,
    * or to enable local execution.
+   *
+   * @return the port number of a local socket which serves the data collected from this job.
    */
   def runJob(
       sc: SparkContext,
       rdd: JavaRDD[Array[Byte]],
       partitions: JArrayList[Int],
-      allowLocal: Boolean): Iterator[Array[Byte]] = {
+      allowLocal: Boolean): Int = {
     type ByteArray = Array[Byte]
     type UnrolledPartition = Array[ByteArray]
     val allPartitions: Array[UnrolledPartition] =
       sc.runJob(rdd, (x: Iterator[ByteArray]) => x.toArray, partitions, allowLocal)
     val flattenedPartition: UnrolledPartition = Array.concat(allPartitions: _*)
-    flattenedPartition.iterator
+    serveIterator(flattenedPartition.iterator,
+      s"serve RDD ${rdd.id} with partitions ${partitions.mkString(",")}")
+  }
+
+  /**
+   * A helper function to collect an RDD as an iterator, then serve it via a socket.
+   *
+   * @return the port number of a local socket which serves the data collected from this job.
+   */
+  def collectAndServe[T](rdd: RDD[T]): Int = {
+    serveIterator(rdd.collect().iterator, s"serve RDD ${rdd.id}")
   }
 
   def readRDDFromFile(sc: JavaSparkContext, filename: String, parallelism: Int):
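After this change, `runJob` and `collectAndServe` return only a port number; the collected bytes are streamed over a local socket instead of being handed back through the Py4J call. Below is a minimal, hypothetical client sketch (not part of this patch) showing how a caller could drain such a port, assuming each record is framed as a 4-byte big-endian length followed by the payload, matching what `writeIteratorToStream` writes for byte-array items; `ServedPortClient` and `readServed` are illustrative names:

```scala
import java.io.{DataInputStream, EOFException}
import java.net.{InetAddress, Socket}
import scala.collection.mutable.ArrayBuffer

object ServedPortClient {
  // Hypothetical client for the port returned by runJob/collectAndServe.
  // Assumes each record is a 4-byte big-endian length followed by that many
  // bytes; the stream simply ends when the server closes the connection.
  def readServed(port: Int): Array[Array[Byte]] = {
    val sock = new Socket(InetAddress.getByName("localhost"), port)
    val in = new DataInputStream(sock.getInputStream)
    val records = ArrayBuffer.empty[Array[Byte]]
    try {
      while (true) {
        val length = in.readInt()  // throws EOFException once the server closes
        val buf = new Array[Byte](length)
        in.readFully(buf)
        records += buf
      }
    } catch {
      case _: EOFException => // end of stream: all records received
    } finally {
      sock.close()
    }
    records.toArray
  }
}
```

In PySpark itself, the Python worker plays this client role, so only the port number crosses the JVM/Python boundary while the data flows over the socket.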
@@ -575,15 +588,44 @@ private[spark] object PythonRDD extends Logging {
     dataOut.write(bytes)
   }
 
-  def writeToFile[T](items: java.util.Iterator[T], filename: String) {
-    import scala.collection.JavaConverters._
-    writeToFile(items.asScala, filename)
-  }
+  /**
+   * Create a socket server and a background thread to serve the data in `items`.
+   *
+   * The socket server can only accept one connection, and closes if no connection
+   * arrives within 3 seconds.
+   *
+   * Once a connection comes in, it tries to serialize all the data in `items`
+   * and send them over this connection.
+   *
+   * The thread terminates after all the data are sent or an exception occurs.
+   */
+  private def serveIterator[T](items: Iterator[T], threadName: String): Int = {
+    val serverSocket = new ServerSocket(0, 1)
+    serverSocket.setReuseAddress(true)
+    // Close the socket if no connection in 3 seconds
+    serverSocket.setSoTimeout(3000)
+
+    new Thread(threadName) {
+      setDaemon(true)
+      override def run() {
+        try {
+          val sock = serverSocket.accept()
+          val out = new DataOutputStream(new BufferedOutputStream(sock.getOutputStream))
+          try {
+            writeIteratorToStream(items, out)
+          } finally {
+            out.close()
+          }
+        } catch {
+          case NonFatal(e) =>
+            logError("Error while sending iterator", e)
+        } finally {
+          serverSocket.close()
+        }
+      }
+    }.start()
 
-  def writeToFile[T](items: Iterator[T], filename: String) {
-    val file = new DataOutputStream(new FileOutputStream(filename))
-    writeIteratorToStream(items, file)
-    file.close()
+    serverSocket.getLocalPort
   }
 
   private def getMergedConf(confAsMap: java.util.HashMap[String, String],
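For reference, the server-side pattern above is small enough to reproduce standalone. The sketch below mirrors `serveIterator` under the same design (ephemeral port, backlog of 1, 3-second accept timeout, daemon thread), with a length-prefixed stand-in for `writeIteratorToStream`; the object and method names are illustrative, not Spark APIs:

```scala
import java.io.{BufferedOutputStream, DataOutputStream}
import java.net.ServerSocket
import scala.util.control.NonFatal

object ServeIteratorSketch {
  /** Bind an ephemeral port, serve one connection from a daemon thread,
    * give up if nobody connects within 3 seconds, and return the port. */
  def serve(items: Iterator[Array[Byte]], threadName: String): Int = {
    val serverSocket = new ServerSocket(0, 1) // port 0 = ephemeral, backlog 1
    serverSocket.setReuseAddress(true)
    serverSocket.setSoTimeout(3000)           // accept() aborts after 3 seconds

    new Thread(threadName) {
      setDaemon(true)                         // never blocks JVM shutdown
      override def run(): Unit = {
        try {
          val sock = serverSocket.accept()
          val out = new DataOutputStream(new BufferedOutputStream(sock.getOutputStream))
          try {
            // Stand-in for writeIteratorToStream: length-prefixed records.
            items.foreach { bytes =>
              out.writeInt(bytes.length)
              out.write(bytes)
            }
          } finally {
            out.close()
          }
        } catch {
          case NonFatal(e) => e.printStackTrace()
        } finally {
          serverSocket.close()                // one connection only
        }
      }
    }.start()

    serverSocket.getLocalPort                 // caller connects to this port
  }
}
```

Binding port 0 lets the OS pick a free port, the backlog of 1 plus the unconditional `serverSocket.close()` enforce the single-connection contract, and the daemon flag ensures a client that never connects cannot keep the JVM alive beyond the accept timeout.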