@@ -56,122 +56,37 @@ private[spark] class PythonRDD[T: ClassTag](
     val env = SparkEnv.get
     val worker: Socket = env.createPythonWorker(pythonExec, envVars.toMap)
 
-    // Ensure worker socket is closed on task completion. Closing sockets is idempotent.
-    context.addOnCompleteCallback(() =>
+    // Start a thread to feed the process input from our parent's iterator
+    val writerThread = new WriterThread(env, worker, split, context)
+
+    context.addOnCompleteCallback { () =>
+      writerThread.shutdownOnTaskCompletion()
+
+      // Cleanup the worker socket. This will also cause the Python worker to exit.
       try {
         worker.close()
       } catch {
         case e: Exception => logWarning("Failed to close worker socket", e)
       }
-    )
-
-    @volatile var readerException: Exception = null
-
-    // Start a thread to feed the process input from our parent's iterator
-    new Thread("stdin writer for " + pythonExec) {
-      override def run() {
-        try {
-          SparkEnv.set(env)
-          val stream = new BufferedOutputStream(worker.getOutputStream, bufferSize)
-          val dataOut = new DataOutputStream(stream)
-          // Partition index
-          dataOut.writeInt(split.index)
-          // sparkFilesDir
-          PythonRDD.writeUTF(SparkFiles.getRootDirectory, dataOut)
-          // Broadcast variables
-          dataOut.writeInt(broadcastVars.length)
-          for (broadcast <- broadcastVars) {
-            dataOut.writeLong(broadcast.id)
-            dataOut.writeInt(broadcast.value.length)
-            dataOut.write(broadcast.value)
-          }
-          // Python includes (*.zip and *.egg files)
-          dataOut.writeInt(pythonIncludes.length)
-          for (include <- pythonIncludes) {
-            PythonRDD.writeUTF(include, dataOut)
-          }
-          dataOut.flush()
-          // Serialized command:
-          dataOut.writeInt(command.length)
-          dataOut.write(command)
-          // Data values
-          PythonRDD.writeIteratorToStream(parent.iterator(split, context), dataOut)
-          dataOut.flush()
-          worker.shutdownOutput()
-        } catch {
-
-          case e: java.io.FileNotFoundException =>
-            readerException = e
-            Try(worker.shutdownOutput()) // kill Python worker process
-
-          case e: IOException =>
-            // This can happen for legitimate reasons if the Python code stops returning data
-            // before we are done passing elements through, e.g., for take(). Just log a message to
-            // say it happened (as it could also be hiding a real IOException from a data source).
-            logInfo("stdin writer to Python finished early (may not be an error)", e)
-
-          case e: Exception =>
-            // We must avoid throwing exceptions here, because the thread uncaught exception handler
-            // will kill the whole executor (see Executor).
-            readerException = e
-            Try(worker.shutdownOutput()) // kill Python worker process
-        }
-      }
-    }.start()
-
-    // Necessary to distinguish between a task that has failed and a task that is finished
-    @volatile var complete: Boolean = false
-
-    // It is necessary to have a monitor thread for python workers if the user cancels with
-    // interrupts disabled. In that case we will need to explicitly kill the worker, otherwise the
-    // threads can block indefinitely.
-    new Thread(s"Worker Monitor for $pythonExec") {
-      override def run() {
-        // Kill the worker if it is interrupted or completed
-        // When a python task completes, the context is always set to interupted
-        while (!context.interrupted) {
-          Thread.sleep(2000)
-        }
-        if (!complete) {
-          try {
-            logWarning("Incomplete task interrupted: Attempting to kill Python Worker")
-            env.destroyPythonWorker(pythonExec, envVars.toMap)
-          } catch {
-            case e: Exception =>
-              logError("Exception when trying to kill worker", e)
-          }
-        }
-      }
-    }.start()
-
-    /*
-     * Partial fix for SPARK-1019: Attempts to stop reading the input stream since
-     * other completion callbacks might invalidate the input. Because interruption
-     * is not synchronous this still leaves a potential race where the interruption is
-     * processed only after the stream becomes invalid.
-     */
-    context.addOnCompleteCallback{ () =>
-      complete = true // Indicate that the task has completed successfully
-      context.interrupted = true
     }
 
+    writerThread.start()
+    new MonitorThread(env, worker, context).start()
+
     // Return an iterator that read lines from the process's stdout
     val stream = new DataInputStream(new BufferedInputStream(worker.getInputStream, bufferSize))
     val stdoutIterator = new Iterator[Array[Byte]] {
       def next(): Array[Byte] = {
         val obj = _nextObj
         if (hasNext) {
-          // FIXME: can deadlock if worker is waiting for us to
-          // respond to current message (currently irrelevant because
-          // output is shutdown before we read any input)
           _nextObj = read()
         }
         obj
       }
 
       private def read(): Array[Byte] = {
-        if (readerException != null) {
-          throw readerException
+        if (writerThread.exception.isDefined) {
+          throw writerThread.exception.get
         }
         try {
           stream.readInt() match {
@@ -190,13 +105,14 @@ private[spark] class PythonRDD[T: ClassTag](
               val total = finishTime - startTime
               logInfo("Times: total = %s, boot = %s, init = %s, finish = %s".format(total, boot,
                 init, finish))
-              read
+              read()
             case SpecialLengths.PYTHON_EXCEPTION_THROWN =>
               // Signals that an exception has been thrown in python
               val exLength = stream.readInt()
               val obj = new Array[Byte](exLength)
               stream.readFully(obj)
-              throw new PythonException(new String(obj, "utf-8"), readerException)
+              throw new PythonException(new String(obj, "utf-8"),
+                writerThread.exception.getOrElse(null))
             case SpecialLengths.END_OF_DATA_SECTION =>
               // We've finished the data section of the output, but we can still
              // read some accumulator updates:
@@ -210,10 +126,15 @@ private[spark] class PythonRDD[T: ClassTag](
               Array.empty[Byte]
           }
         } catch {
-          case e: Exception if readerException != null =>
+
+          case e: Exception if context.interrupted =>
+            logDebug("Exception thrown after task interruption", e)
+            throw new TaskKilledException
+
+          case e: Exception if writerThread.exception.isDefined =>
             logError("Python worker exited unexpectedly (crashed)", e)
-            logError("Python crash may have been caused by prior exception:", readerException)
-            throw readerException
+            logError("This may have been caused by a prior exception:", writerThread.exception.get)
+            throw writerThread.exception.get
 
           case eof: EOFException =>
             throw new SparkException("Python worker exited unexpectedly (crashed)", eof)
@@ -224,10 +145,100 @@ private[spark] class PythonRDD[T: ClassTag](
 
       def hasNext = _nextObj.length != 0
     }
-    stdoutIterator
+    new InterruptibleIterator(context, stdoutIterator)
   }
 
   val asJavaRDD: JavaRDD[Array[Byte]] = JavaRDD.fromRDD(this)
+
+  /**
+   * The thread responsible for writing the data from the PythonRDD's parent iterator to the
+   * Python process.
+   */
+  class WriterThread(env: SparkEnv, worker: Socket, split: Partition, context: TaskContext)
+    extends Thread(s"stdout writer for $pythonExec") {
+
+    @volatile private var _exception: Exception = null
+
+    setDaemon(true)
+
+    /** Contains the exception thrown while writing the parent iterator to the Python process. */
+    def exception: Option[Exception] = Option(_exception)
+
+    /** Terminates the writer thread, ignoring any exceptions that may occur due to cleanup. */
+    def shutdownOnTaskCompletion() {
+      assert(context.completed)
+      this.interrupt()
+    }
+
+    override def run() {
+      try {
+        SparkEnv.set(env)
+        val stream = new BufferedOutputStream(worker.getOutputStream, bufferSize)
+        val dataOut = new DataOutputStream(stream)
+        // Partition index
+        dataOut.writeInt(split.index)
+        // sparkFilesDir
+        PythonRDD.writeUTF(SparkFiles.getRootDirectory, dataOut)
+        // Broadcast variables
+        dataOut.writeInt(broadcastVars.length)
+        for (broadcast <- broadcastVars) {
+          dataOut.writeLong(broadcast.id)
+          dataOut.writeInt(broadcast.value.length)
+          dataOut.write(broadcast.value)
+        }
+        // Python includes (*.zip and *.egg files)
+        dataOut.writeInt(pythonIncludes.length)
+        for (include <- pythonIncludes) {
+          PythonRDD.writeUTF(include, dataOut)
+        }
+        dataOut.flush()
+        // Serialized command:
+        dataOut.writeInt(command.length)
+        dataOut.write(command)
+        // Data values
+        PythonRDD.writeIteratorToStream(parent.iterator(split, context), dataOut)
+        dataOut.flush()
+      } catch {
+        case e: Exception if context.completed || context.interrupted =>
+          logDebug("Exception thrown after task completion (likely due to cleanup)", e)
+
+        case e: Exception =>
+          // We must avoid throwing exceptions here, because the thread uncaught exception handler
+          // will kill the whole executor (see org.apache.spark.executor.Executor).
+          _exception = e
+      } finally {
+        Try(worker.shutdownOutput()) // kill Python worker process
+      }
+    }
+  }
+
+  /**
+   * It is necessary to have a monitor thread for python workers if the user cancels with
+   * interrupts disabled. In that case we will need to explicitly kill the worker, otherwise the
+   * threads can block indefinitely.
+   */
+  class MonitorThread(env: SparkEnv, worker: Socket, context: TaskContext)
+    extends Thread(s"Worker Monitor for $pythonExec") {
+
+    setDaemon(true)
+
+    override def run() {
+      // Kill the worker if it is interrupted, checking until task completion.
+      // TODO: This has a race condition if interruption occurs, as completed may still become true.
+      while (!context.interrupted && !context.completed) {
+        Thread.sleep(2000)
+      }
+      if (!context.completed) {
+        try {
+          logWarning("Incomplete task interrupted: Attempting to kill Python Worker")
+          env.destroyPythonWorker(pythonExec, envVars.toMap)
+        } catch {
+          case e: Exception =>
+            logError("Exception when trying to kill worker", e)
+        }
+      }
+    }
+  }
 }
 
 /** Thrown for exceptions in user Python code. */