
Commit b6190a3

WweiL authored and HyukjinKwon committed
[SPARK-45056][PYTHON][SS][CONNECT] Termination tests for streamingQueryListener and foreachBatch
### What changes were proposed in this pull request?

Add termination tests for StreamingQueryListener and foreachBatch. The behavior is mimicked by creating, on the server side, the same query that would have been created if the same Python query were run from the client side. For example, in foreachBatch a Python foreachBatch function is normally serialized with CloudPickleSerializer on the client and passed to the server; here we start another Python process on the server, call the same CloudPickleSerializer, pass the resulting bytes to the server, and construct a `SimplePythonFunction` accordingly. The code is also refactored a bit for testing purposes.

### Why are the changes needed?

Necessary tests.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

Test-only addition.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes #42779 from WweiL/SPARK-44435-followup-termination-tests.

Lead-authored-by: Wei Liu <[email protected]>
Co-authored-by: Wei Liu <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
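For context, the test helpers reproduce the client-side serialization step described above. A minimal sketch of that step in plain PySpark (the output path and the trivial lambda are illustrative placeholders, not part of this change):

```python
# Minimal sketch: CloudPickleSerializer turns a Python foreachBatch callable
# into bytes, which the server-side test then wraps in a SimplePythonFunction.
# The output path and the lambda below are illustrative only.
from pyspark.serializers import CloudPickleSerializer

func_bytes = CloudPickleSerializer().dumps(lambda df, batch_id: batch_id)

with open("/tmp/foreach_batch_func.bin", "wb") as f:  # hypothetical path
    f.write(func_bytes)
```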
1 parent 21c27d5 commit b6190a3

File tree

9 files changed (+242, -18 lines)

.github/workflows/build_and_test.yml

Lines changed: 3 additions & 3 deletions
@@ -256,14 +256,14 @@ jobs:
       # We should install one Python that is higher then 3+ for SQL and Yarn because:
       # - SQL component also has Python related tests, for example, IntegratedUDFTestUtils.
       # - Yarn has a Python specific test too, for example, YarnClusterSuite.
-      if: contains(matrix.modules, 'yarn') || (contains(matrix.modules, 'sql') && !contains(matrix.modules, 'sql-'))
+      if: contains(matrix.modules, 'yarn') || (contains(matrix.modules, 'sql') && !contains(matrix.modules, 'sql-')) || contains(matrix.modules, 'connect')
       with:
         python-version: 3.8
         architecture: x64
     - name: Install Python packages (Python 3.8)
-      if: (contains(matrix.modules, 'sql') && !contains(matrix.modules, 'sql-'))
+      if: (contains(matrix.modules, 'sql') && !contains(matrix.modules, 'sql-')) || contains(matrix.modules, 'connect')
       run: |
-        python3.8 -m pip install 'numpy>=1.20.0' 'pyarrow==12.0.1' pandas scipy unittest-xml-reporting 'grpcio==1.56.0' 'protobuf==3.20.3'
+        python3.8 -m pip install 'numpy>=1.20.0' 'pyarrow==12.0.1' pandas scipy unittest-xml-reporting 'grpcio>=1.48,<1.57' 'grpcio-status>=1.48,<1.57' 'protobuf==3.20.3'
         python3.8 -m pip list
     # Run the tests.
     - name: Run tests

connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala

Lines changed: 1 addition & 2 deletions
@@ -3131,8 +3131,7 @@ class SparkConnectPlanner(val sessionHolder: SessionHolder) extends Logging {
     val listener = if (command.getAddListener.hasPythonListenerPayload) {
       new PythonStreamingQueryListener(
         transformPythonFunction(command.getAddListener.getPythonListenerPayload),
-        sessionHolder,
-        pythonExec)
+        sessionHolder)
     } else {
       val listenerPacket = Utils
         .deserialize[StreamingListenerPacket](

connector/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/StreamingQueryListenerHelper.scala

Lines changed: 3 additions & 5 deletions
@@ -26,15 +26,13 @@ import org.apache.spark.sql.streaming.StreamingQueryListener
  * instance of this class starts a python process, inside which has the python handling logic.
  * When a new event is received, it is serialized to json, and passed to the python process.
  */
-class PythonStreamingQueryListener(
-    listener: SimplePythonFunction,
-    sessionHolder: SessionHolder,
-    pythonExec: String)
+class PythonStreamingQueryListener(listener: SimplePythonFunction, sessionHolder: SessionHolder)
   extends StreamingQueryListener {
 
   private val port = SparkConnectService.localPort
   private val connectUrl = s"sc://localhost:$port/;user_id=${sessionHolder.userId}"
-  private val runner = StreamingPythonRunner(
+  // Scoped for testing
+  private[connect] val runner = StreamingPythonRunner(
     listener,
     connectUrl,
     sessionHolder.sessionId,

connector/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectSessionHodlerSuite.scala

Lines changed: 205 additions & 0 deletions
@@ -17,7 +17,20 @@
 
 package org.apache.spark.sql.connect.service
 
+import java.nio.charset.StandardCharsets
+import java.nio.file.Files
+
+import scala.collection.JavaConverters._
+import scala.collection.mutable
+import scala.sys.process.Process
+
+import com.google.common.collect.Lists
+import org.scalatest.time.SpanSugar._
+
+import org.apache.spark.api.python.SimplePythonFunction
+import org.apache.spark.sql.IntegratedUDFTestUtils
 import org.apache.spark.sql.connect.common.InvalidPlanInput
+import org.apache.spark.sql.connect.planner.{PythonStreamingQueryListener, StreamingForeachBatchHelper}
 import org.apache.spark.sql.test.SharedSparkSession
 
 class SparkConnectSessionHolderSuite extends SharedSparkSession {
@@ -79,4 +92,196 @@ class SparkConnectSessionHolderSuite extends SharedSparkSession {
       sessionHolder.getDataFrameOrThrow(key1)
     }
   }
+
+  private def streamingForeachBatchFunction(pysparkPythonPath: String): Array[Byte] = {
+    var binaryFunc: Array[Byte] = null
+    withTempPath { path =>
+      Process(
+        Seq(
+          IntegratedUDFTestUtils.pythonExec,
+          "-c",
+          "from pyspark.serializers import CloudPickleSerializer; " +
+            s"f = open('$path', 'wb');" +
+            "f.write(CloudPickleSerializer().dumps((" +
+            "lambda df, batchId: batchId)))"),
+        None,
+        "PYTHONPATH" -> pysparkPythonPath).!!
+      binaryFunc = Files.readAllBytes(path.toPath)
+    }
+    assert(binaryFunc != null)
+    binaryFunc
+  }
+
+  private def streamingQueryListenerFunction(pysparkPythonPath: String): Array[Byte] = {
+    var binaryFunc: Array[Byte] = null
+    val pythonScript =
+      """
+        |from pyspark.sql.streaming.listener import StreamingQueryListener
+        |
+        |class MyListener(StreamingQueryListener):
+        |    def onQueryStarted(e):
+        |        pass
+        |
+        |    def onQueryIdle(e):
+        |        pass
+        |
+        |    def onQueryProgress(e):
+        |        pass
+        |
+        |    def onQueryTerminated(e):
+        |        pass
+        |
+        |listener = MyListener()
+        """.stripMargin
+    withTempPath { codePath =>
+      Files.write(codePath.toPath, pythonScript.getBytes(StandardCharsets.UTF_8))
+      withTempPath { path =>
+        Process(
+          Seq(
+            IntegratedUDFTestUtils.pythonExec,
+            "-c",
+            "from pyspark.serializers import CloudPickleSerializer; " +
+              s"f = open('$path', 'wb');" +
+              s"exec(open('$codePath', 'r').read());" +
+              "f.write(CloudPickleSerializer().dumps(listener))"),
+          None,
+          "PYTHONPATH" -> pysparkPythonPath).!!
+        binaryFunc = Files.readAllBytes(path.toPath)
+      }
+    }
+    assert(binaryFunc != null)
+    binaryFunc
+  }
+
+  private def dummyPythonFunction(sessionHolder: SessionHolder)(
+      fcn: String => Array[Byte]): SimplePythonFunction = {
+    val sparkPythonPath =
+      s"${IntegratedUDFTestUtils.pysparkPythonPath}:${IntegratedUDFTestUtils.pythonPath}"
+
+    SimplePythonFunction(
+      command = fcn(sparkPythonPath),
+      envVars = mutable.Map("PYTHONPATH" -> sparkPythonPath).asJava,
+      pythonIncludes = sessionHolder.artifactManager.getSparkConnectPythonIncludes.asJava,
+      pythonExec = IntegratedUDFTestUtils.pythonExec,
+      pythonVer = IntegratedUDFTestUtils.pythonVer,
+      broadcastVars = Lists.newArrayList(),
+      accumulator = null)
+  }
+
+  test("python foreachBatch process: process terminates after query is stopped") {
+    // scalastyle:off assume
+    assume(IntegratedUDFTestUtils.shouldTestPythonUDFs)
+    // scalastyle:on assume
+
+    val sessionHolder = SessionHolder.forTesting(spark)
+    try {
+      SparkConnectService.start(spark.sparkContext)
+
+      val pythonFn = dummyPythonFunction(sessionHolder)(streamingForeachBatchFunction)
+      val (fn1, cleaner1) =
+        StreamingForeachBatchHelper.pythonForeachBatchWrapper(pythonFn, sessionHolder)
+      val (fn2, cleaner2) =
+        StreamingForeachBatchHelper.pythonForeachBatchWrapper(pythonFn, sessionHolder)
+
+      val query1 = spark.readStream
+        .format("rate")
+        .load()
+        .writeStream
+        .format("memory")
+        .queryName("foreachBatch_termination_test_q1")
+        .foreachBatch(fn1)
+        .start()
+
+      val query2 = spark.readStream
+        .format("rate")
+        .load()
+        .writeStream
+        .format("memory")
+        .queryName("foreachBatch_termination_test_q2")
+        .foreachBatch(fn2)
+        .start()
+
+      sessionHolder.streamingForeachBatchRunnerCleanerCache
+        .registerCleanerForQuery(query1, cleaner1)
+      sessionHolder.streamingForeachBatchRunnerCleanerCache
+        .registerCleanerForQuery(query2, cleaner2)
+
+      val (runner1, runner2) = (cleaner1.runner, cleaner2.runner)
+
+      // assert both python processes are running
+      assert(!runner1.isWorkerStopped().get)
+      assert(!runner2.isWorkerStopped().get)
+      // stop query1
+      query1.stop()
+      // assert query1's python process is not running
+      eventually(timeout(30.seconds)) {
+        assert(runner1.isWorkerStopped().get)
+        assert(!runner2.isWorkerStopped().get)
+      }
+
+      // stop query2
+      query2.stop()
+      eventually(timeout(30.seconds)) {
+        // assert query2's python process is not running
+        assert(runner2.isWorkerStopped().get)
+      }
+
+      assert(spark.streams.active.isEmpty) // no running query
+      assert(spark.streams.listListeners().length == 1) // only process termination listener
+    } finally {
+      SparkConnectService.stop()
+      // remove process termination listener
+      spark.streams.removeListener(spark.streams.listListeners()(0))
+    }
+  }
+
+  test("python listener process: process terminates after listener is removed") {
+    // scalastyle:off assume
+    assume(IntegratedUDFTestUtils.shouldTestPythonUDFs)
+    // scalastyle:on assume
+
+    val sessionHolder = SessionHolder.forTesting(spark)
+    try {
+      SparkConnectService.start(spark.sparkContext)
+
+      val pythonFn = dummyPythonFunction(sessionHolder)(streamingQueryListenerFunction)
+
+      val id1 = "listener_removeListener_test_1"
+      val id2 = "listener_removeListener_test_2"
+      val listener1 = new PythonStreamingQueryListener(pythonFn, sessionHolder)
+      val listener2 = new PythonStreamingQueryListener(pythonFn, sessionHolder)
+
+      sessionHolder.cacheListenerById(id1, listener1)
+      spark.streams.addListener(listener1)
+      sessionHolder.cacheListenerById(id2, listener2)
+      spark.streams.addListener(listener2)
+
+      val (runner1, runner2) = (listener1.runner, listener2.runner)
+
+      // assert both python processes are running
+      assert(!runner1.isWorkerStopped().get)
+      assert(!runner2.isWorkerStopped().get)
+
+      // remove listener1
+      spark.streams.removeListener(listener1)
+      sessionHolder.removeCachedListener(id1)
+      // assert listener1's python process is not running
+      eventually(timeout(30.seconds)) {
+        assert(runner1.isWorkerStopped().get)
+        assert(!runner2.isWorkerStopped().get)
+      }
+
+      // remove listener2
+      spark.streams.removeListener(listener2)
+      sessionHolder.removeCachedListener(id2)
+      eventually(timeout(30.seconds)) {
+        // assert listener2's python process is not running
+        assert(runner2.isWorkerStopped().get)
+        // all listeners are removed
+        assert(spark.streams.listListeners().isEmpty)
+      }
+    } finally {
+      SparkConnectService.stop()
+    }
+  }
 }

core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala

Lines changed: 5 additions & 0 deletions
@@ -413,6 +413,11 @@ private[spark] class PythonWorkerFactory(
       }
     }
   }
+
+  def isWorkerStopped(worker: PythonWorker): Boolean = {
+    assert(!useDaemon, "isWorkerStopped() is not supported for daemon mode")
+    simpleWorkers.get(worker).exists(!_.isAlive)
+  }
 }
 
 private[spark] object PythonWorkerFactory {

core/src/main/scala/org/apache/spark/api/python/StreamingPythonRunner.scala

Lines changed: 23 additions & 3 deletions
@@ -108,10 +108,30 @@ private[spark] class StreamingPythonRunner(
    * Stops the Python worker.
    */
   def stop(): Unit = {
-    pythonWorker.foreach { worker =>
+    logInfo(s"Stopping streaming runner for sessionId: $sessionId, module: $workerModule.")
+
+    try {
       pythonWorkerFactory.foreach { factory =>
-        factory.stopWorker(worker)
-        factory.stop()
+        pythonWorker.foreach { worker =>
+          factory.stopWorker(worker)
+          factory.stop()
+        }
+      }
+    } catch {
+      case e: Exception =>
+        logError("Exception when trying to kill worker", e)
+    }
+  }
+
+  /**
+   * Returns whether the Python worker has been stopped.
+   * @return Some(true) if the Python worker has been stopped.
+   *         None if either the Python worker or the Python worker factory is not initialized.
+   */
+  def isWorkerStopped(): Option[Boolean] = {
+    pythonWorkerFactory.flatMap { factory =>
+      pythonWorker.map { worker =>
+        factory.isWorkerStopped(worker)
       }
     }
   }

python/pyspark/sql/tests/connect/streaming/test_parity_foreach_batch.py

Lines changed: 0 additions & 1 deletion
@@ -31,7 +31,6 @@ def test_streaming_foreach_batch_propagates_python_errors(self):
     def test_streaming_foreach_batch_graceful_stop(self):
         super().test_streaming_foreach_batch_graceful_stop()
 
-    # class StreamingForeachBatchParityTests(ReusedConnectTestCase):
     def test_accessing_spark_session(self):
         spark = self.spark
 

python/pyspark/sql/tests/streaming/test_streaming_foreach_batch.py

Lines changed: 0 additions & 2 deletions
@@ -135,8 +135,6 @@ def func(df: DataFrame, batch_id: int):
         df = df.union(df)
         self.assertEqual(sorted(df.collect()), sorted(actual.collect()))
 
-        # write to delta table?
-
     @staticmethod
     def my_test_function_2():
         return 2

sql/core/src/test/scala/org/apache/spark/sql/IntegratedUDFTestUtils.scala

Lines changed: 2 additions & 2 deletions
@@ -97,14 +97,14 @@ import org.apache.spark.sql.types.{DataType, IntegerType, NullType, StringType,
 object IntegratedUDFTestUtils extends SQLHelper {
   import scala.sys.process._
 
-  private lazy val pythonPath = sys.env.getOrElse("PYTHONPATH", "")
+  private[spark] lazy val pythonPath = sys.env.getOrElse("PYTHONPATH", "")
 
   // Note that we will directly refer pyspark's source, not the zip from a regular build.
   // It is possible the test is being ran without the build.
   private lazy val sourcePath = Paths.get(sparkHome, "python").toAbsolutePath
   private lazy val py4jPath = Paths.get(
     sparkHome, "python", "lib", PythonUtils.PY4J_ZIP_NAME).toAbsolutePath
-  private lazy val pysparkPythonPath = s"$py4jPath:$sourcePath"
+  private[spark] lazy val pysparkPythonPath = s"$py4jPath:$sourcePath"
 
   private lazy val isPythonAvailable: Boolean = TestUtils.testCommandAvailable(pythonExec)
 