@@ -345,6 +345,8 @@ private[spark] object PythonRDD extends Logging {
345345 * This method will serve an iterator of an array that contains all elements in the RDD
346346 * (effectively a collect()), but allows you to run on a certain subset of partitions,
347347 * or to enable local execution.
348+ *
349+ * @return the port number of a local socket which serves the data collected from this job.
348350 */
349351 def runJob (
350352 sc : SparkContext ,
@@ -356,14 +358,17 @@ private[spark] object PythonRDD extends Logging {
356358 val allPartitions : Array [UnrolledPartition ] =
357359 sc.runJob(rdd, (x : Iterator [ByteArray ]) => x.toArray, partitions, allowLocal)
358360 val flattenedPartition : UnrolledPartition = Array .concat(allPartitions : _* )
359- serveIterator(flattenedPartition.iterator)
361+ serveIterator(flattenedPartition.iterator,
362+ s " serve RDD ${rdd.id} with partitions ${partitions.mkString(" ," )}" )
360363 }
361364
362365 /**
363- * A helper function to collect an RDD as an iterator, then serve it via socket
366+ * A helper function to collect an RDD as an iterator, then serve it via socket.
367+ *
368+ * @return the port number of a local socket which serves the data collected from this job.
364369 */
365370 def collectAndServe [T ](rdd : RDD [T ]): Int = {
366- serveIterator(rdd.collect().iterator)
371+ serveIterator(rdd.collect().iterator, s " serve RDD ${rdd.id} " )
367372 }
368373
369374 def readRDDFromFile (sc : JavaSparkContext , filename : String , parallelism : Int ):
@@ -583,12 +588,24 @@ private[spark] object PythonRDD extends Logging {
583588 dataOut.write(bytes)
584589 }
585590
586- private def serveIterator [T ](items : Iterator [T ]): Int = {
591+ /**
592+ * Create a socket server and a background thread to serve the data in `items`.
593+ *
594+ * The socket server can only accept one connection, and will close if no
595+ * connection arrives within 3 seconds.
596+ *
597+ * Once a connection comes in, it tries to serialize all the data in `items`
598+ * and send them into this connection.
599+ *
600+ * The thread will terminate after all the data are sent or any exception happens.
601+ */
602+ private def serveIterator [T ](items : Iterator [T ], threadName : String ): Int = {
587603 val serverSocket = new ServerSocket (0 , 1 )
588604 serverSocket.setReuseAddress(true )
605+ // Close the socket if no connection in 3 seconds
589606 serverSocket.setSoTimeout(3000 )
590607
591- new Thread (" serve iterator " ) {
608+ new Thread (threadName ) {
592609 setDaemon(true )
593610 override def run () {
594611 try {
@@ -601,7 +618,7 @@ private[spark] object PythonRDD extends Logging {
601618 }
602619 } catch {
603620 case NonFatal (e) =>
604- logError(s " Error while sending iterator: $e " )
621+ logError(s " Error while sending iterator " , e )
605622 } finally {
606623 serverSocket.close()
607624 }
0 commit comments