@@ -86,7 +86,7 @@ private[spark] case class PythonFunction(
 private[spark] case class ChainedPythonFunctions(funcs: Seq[PythonFunction])
 
 /** Thrown for exceptions in user Python code. */
-private[spark] class PythonException(msg: String, cause: Exception)
+private[spark] class PythonException(msg: String, cause: Throwable)
   extends RuntimeException(msg, cause)
 
 /**
@@ -163,8 +163,63 @@ private[spark] object PythonRDD extends Logging {
     serveIterator(rdd.collect().iterator, s"serve RDD ${rdd.id}")
   }
 
+  /**
+   * A helper function to create a local RDD iterator and serve it via socket. Partitions are
+   * collected as separate jobs, in order of index. Partition data is first requested by a
+   * nonzero integer to start a collection job. The response is prefaced by an integer: 1
+   * means partition data will be served, 0 means the local iterator has been consumed,
+   * and -1 means an error occurred during collection. This function is used by
+   * pyspark.rdd._local_iterator_from_socket().
+   *
+   * @return 2-tuple (as a Java array) with the port number of a local socket which serves the
+   *         data collected from these jobs, and the secret for authentication.
+   */
   def toLocalIteratorAndServe[T](rdd: RDD[T]): Array[Any] = {
-    serveIterator(rdd.toLocalIterator, s"serve toLocalIterator")
+    val (port, secret) = SocketAuthServer.setupOneConnectionServer(
+      authHelper, "serve toLocalIterator") { s =>
+      val out = new DataOutputStream(s.getOutputStream)
+      val in = new DataInputStream(s.getInputStream)
+      Utils.tryWithSafeFinally {
+
+        // Collects a partition on each iteration
+        val collectPartitionIter = rdd.partitions.indices.iterator.map { i =>
+          rdd.sparkContext.runJob(rdd, (iter: Iterator[Any]) => iter.toArray, Seq(i)).head
+        }
+
+        // Read the request for data and send the next partition if nonzero
+        var complete = false
+        while (!complete && in.readInt() != 0) {
+          if (collectPartitionIter.hasNext) {
+            try {
+              // Attempt to collect the next partition
+              val partitionArray = collectPartitionIter.next()
+
+              // Send response that there is a partition to read
+              out.writeInt(1)
+
+              // Write the next object and signal end of data for this iteration
+              writeIteratorToStream(partitionArray.toIterator, out)
+              out.writeInt(SpecialLengths.END_OF_DATA_SECTION)
+              out.flush()
+            } catch {
+              case e: SparkException =>
+                // Send response that an error occurred, followed by the error message
+                out.writeInt(-1)
+                writeUTF(e.getMessage, out)
+                complete = true
+            }
+          } else {
+            // Send response that there are no more partitions to read, and close
+            out.writeInt(0)
+            complete = true
+          }
+        }
+      } {
+        out.close()
+        in.close()
+      }
+    }
+    Array(port, secret)
   }
 
   def readRDDFromFile(
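
For reference, here is a minimal sketch of the client side of the handshake documented above. The real consumer is `pyspark.rdd._local_iterator_from_socket()` (written in Python), and it authenticates with the returned secret before this exchange begins; the object name here, the loopback connection, and the assumption that `SpecialLengths.END_OF_DATA_SECTION == -1` are illustrative, not part of this commit.

```scala
import java.io.{DataInputStream, DataOutputStream}
import java.net.{InetAddress, Socket}
import java.nio.charset.StandardCharsets

object LocalIteratorClientSketch {
  def consume(port: Int): Unit = {
    val sock = new Socket(InetAddress.getLoopbackAddress, port)
    val out = new DataOutputStream(sock.getOutputStream)
    val in = new DataInputStream(sock.getInputStream)
    try {
      var done = false
      while (!done) {
        // Any nonzero integer requests the next partition; 0 would end the exchange early.
        out.writeInt(1)
        out.flush()
        in.readInt() match {
          case 1 =>
            // Partition data follows as length-prefixed records, terminated by
            // END_OF_DATA_SECTION (assumed here to be -1).
            var len = in.readInt()
            while (len != -1) {
              in.readFully(new Array[Byte](len)) // each record is one serialized object
              len = in.readInt()
            }
          case 0 =>
            done = true // the local iterator has been fully consumed
          case -1 =>
            // Collection failed: writeUTF sends a length-prefixed UTF-8 message.
            val msg = new Array[Byte](in.readInt())
            in.readFully(msg)
            throw new RuntimeException(new String(msg, StandardCharsets.UTF_8))
        }
      }
    } finally {
      in.close()
      out.close()
      sock.close()
    }
  }
}
```

Because each nonzero request triggers at most one collection job on the server, a client that stops reading early (by sending 0) never pays for the remaining partitions.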
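The deferred collection in the new body hinges on wrapping `runJob` in an iterator's `map`: no job runs until `next()` is called, so only one partition is ever materialized on the driver at a time. A standalone sketch of just that pattern (app name, master setting, and data are illustrative):

```scala
import org.apache.spark.{SparkConf, SparkContext}

object PerPartitionCollectSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("sketch").setMaster("local[2]"))
    val rdd = sc.parallelize(1 to 100, 4)

    // Nothing runs yet: each call to next() triggers exactly one job, which
    // collects exactly one partition (by index) to the driver.
    val partitionIter = rdd.partitions.indices.iterator.map { i =>
      sc.runJob(rdd, (it: Iterator[Int]) => it.toArray, Seq(i)).head
    }

    // Consuming the iterator runs the jobs one at a time, in partition order.
    partitionIter.foreach(part => println(part.mkString(", ")))
    sc.stop()
  }
}
```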