diff --git a/python/pyspark/sql/streaming/listener.py b/python/pyspark/sql/streaming/listener.py
index c1c9dce047319..2aa63cdb91ab6 100644
--- a/python/pyspark/sql/streaming/listener.py
+++ b/python/pyspark/sql/streaming/listener.py
@@ -64,16 +64,19 @@ class StreamingQueryListener(ABC):
     """
 
     def _set_spark_session(
-        self, spark: "SparkSession"  # type: ignore[name-defined] # noqa: F821
+        self, session: "SparkSession"  # type: ignore[name-defined] # noqa: F821
     ) -> None:
-        self._sparkSession = spark
+        if self.spark is None:
+            self.spark = session
 
     @property
     def spark(self) -> Optional["SparkSession"]:  # type: ignore[name-defined] # noqa: F821
-        if hasattr(self, "_sparkSession"):
-            return self._sparkSession
-        else:
-            return None
+        return getattr(self, "_sparkSession", None)
+
+    @spark.setter
+    def spark(self, session: "SparkSession") -> None:  # type: ignore[name-defined] # noqa: F821
+        # For backward compatibility
+        self._sparkSession = session
 
     def _init_listener_id(self) -> None:
         self._id = str(uuid.uuid4())
diff --git a/python/pyspark/sql/tests/streaming/test_streaming_listener.py b/python/pyspark/sql/tests/streaming/test_streaming_listener.py
index 15f5575d36479..762fc335b56ad 100644
--- a/python/pyspark/sql/tests/streaming/test_streaming_listener.py
+++ b/python/pyspark/sql/tests/streaming/test_streaming_listener.py
@@ -592,6 +592,23 @@ def test_streaming_query_progress_fromJson(self):
         self.assertEqual(sink.numOutputRows, -1)
         self.assertEqual(sink.metrics, {})
 
+    def test_spark_property_in_listener(self):
+        # SPARK-48560: Make StreamingQueryListener.spark settable
+        class TestListener(StreamingQueryListener):
+            def __init__(self, session):
+                self.spark = session
+
+            def onQueryStarted(self, event):
+                pass
+
+            def onQueryProgress(self, event):
+                pass
+
+            def onQueryTerminated(self, event):
+                pass
+
+        self.assertEqual(TestListener(self.spark).spark, self.spark)
+
 
 if __name__ == "__main__":
     import unittest