Commit b5ccf0d

[SPARK-24396][SS][PYSPARK] Add Structured Streaming ForeachWriter for python
## What changes were proposed in this pull request?

This PR adds `foreach` for streaming queries in Python. Users will be able to specify their processing logic in two different ways.

- As a function that takes a row as input.
- As an object that has `open`, `process`, and `close` methods.

See the Python docs in this PR for more details.

## How was this patch tested?

Added Java and Python unit tests.

Author: Tathagata Das <[email protected]>

Closes #21477 from tdas/SPARK-24396.
1 parent 495d8cf commit b5ccf0d
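
For illustration only (not part of this commit), a minimal sketch of the function-based form described above, assuming a running SparkSession named `spark` and using the built-in `rate` source purely as a stand-in input:

from pyspark.sql import SparkSession

spark = SparkSession.builder.appName("foreach-sketch").getOrCreate()

# Streaming DataFrame with `timestamp` and `value` columns from the rate source.
sdf = spark.readStream.format("rate").load()

def print_row(row):
    # Per-row processing logic; a real writer would push the row to an external sink.
    print(row)

query = sdf.writeStream.foreach(print_row).start()
query.awaitTermination()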

File tree

8 files changed (+811, -72 lines)

python/pyspark/sql/streaming.py

Lines changed: 162 additions & 0 deletions
@@ -854,6 +854,168 @@ def trigger(self, processingTime=None, once=None, continuous=None):
         self._jwrite = self._jwrite.trigger(jTrigger)
         return self
 
+    @since(2.4)
+    def foreach(self, f):
+        """
+        Sets the output of the streaming query to be processed using the provided writer ``f``.
+        This is often used to write the output of a streaming query to arbitrary storage systems.
+        The processing logic can be specified in two ways.
+
+        #. A **function** that takes a row as input.
+            This is a simple way to express your processing logic. Note that this does
+            not allow you to deduplicate generated data when failures cause reprocessing of
+            some input data. That would require you to specify the processing logic in the next
+            way.
+
+        #. An **object** with a ``process`` method and optional ``open`` and ``close`` methods.
+            The object can have the following methods.
+
+            * ``open(partition_id, epoch_id)``: *Optional* method that initializes the processing
+                (for example, open a connection, start a transaction, etc). Additionally, you can
+                use the `partition_id` and `epoch_id` to deduplicate regenerated data
+                (discussed later).
+
+            * ``process(row)``: *Non-optional* method that processes each :class:`Row`.
+
+            * ``close(error)``: *Optional* method that finalizes and cleans up (for example,
+                close connection, commit transaction, etc.) after all rows have been processed.
+
+            The object will be used by Spark in the following way.
+
+            * A single copy of this object is responsible of all the data generated by a
+                single task in a query. In other words, one instance is responsible for
+                processing one partition of the data generated in a distributed manner.
+
+            * This object must be serializable because each task will get a fresh
+                serialized-deserialized copy of the provided object. Hence, it is strongly
+                recommended that any initialization for writing data (e.g. opening a
+                connection or starting a transaction) is done after the `open(...)`
+                method has been called, which signifies that the task is ready to generate data.
+
+            * The lifecycle of the methods are as follows.
+
+                For each partition with ``partition_id``:
+
+                ... For each batch/epoch of streaming data with ``epoch_id``:
+
+                ....... Method ``open(partitionId, epochId)`` is called.
+
+                ....... If ``open(...)`` returns true, for each row in the partition and
+                        batch/epoch, method ``process(row)`` is called.
+
+                ....... Method ``close(errorOrNull)`` is called with error (if any) seen while
+                        processing rows.
+
+            Important points to note:
+
+            * The `partitionId` and `epochId` can be used to deduplicate generated data when
+                failures cause reprocessing of some input data. This depends on the execution
+                mode of the query. If the streaming query is being executed in the micro-batch
+                mode, then every partition represented by a unique tuple (partition_id, epoch_id)
+                is guaranteed to have the same data. Hence, (partition_id, epoch_id) can be used
+                to deduplicate and/or transactionally commit data and achieve exactly-once
+                guarantees. However, if the streaming query is being executed in the continuous
+                mode, then this guarantee does not hold and therefore should not be used for
+                deduplication.
+
+            * The ``close()`` method (if exists) will be called if `open()` method exists and
+                returns successfully (irrespective of the return value), except if the Python
+                crashes in the middle.
+
+        .. note:: Evolving.
+
+        >>> # Print every row using a function
+        >>> def print_row(row):
+        ...     print(row)
+        ...
+        >>> writer = sdf.writeStream.foreach(print_row)
+        >>> # Print every row using a object with process() method
+        >>> class RowPrinter:
+        ...     def open(self, partition_id, epoch_id):
+        ...         print("Opened %d, %d" % (partition_id, epoch_id))
+        ...         return True
+        ...     def process(self, row):
+        ...         print(row)
+        ...     def close(self, error):
+        ...         print("Closed with error: %s" % str(error))
+        ...
+        >>> writer = sdf.writeStream.foreach(RowPrinter())
+        """
+
+        from pyspark.rdd import _wrap_function
+        from pyspark.serializers import PickleSerializer, AutoBatchedSerializer
+        from pyspark.taskcontext import TaskContext
+
+        if callable(f):
+            # The provided object is a callable function that is supposed to be called on each row.
+            # Construct a function that takes an iterator and calls the provided function on each
+            # row.
+            def func_without_process(_, iterator):
+                for x in iterator:
+                    f(x)
+                return iter([])
+
+            func = func_without_process
+
+        else:
+            # The provided object is not a callable function. Then it is expected to have a
+            # 'process(row)' method, and optional 'open(partition_id, epoch_id)' and
+            # 'close(error)' methods.
+
+            if not hasattr(f, 'process'):
+                raise Exception("Provided object does not have a 'process' method")
+
+            if not callable(getattr(f, 'process')):
+                raise Exception("Attribute 'process' in provided object is not callable")
+
+            def doesMethodExist(method_name):
+                exists = hasattr(f, method_name)
+                if exists and not callable(getattr(f, method_name)):
+                    raise Exception(
+                        "Attribute '%s' in provided object is not callable" % method_name)
+                return exists
+
+            open_exists = doesMethodExist('open')
+            close_exists = doesMethodExist('close')
+
+            def func_with_open_process_close(partition_id, iterator):
+                epoch_id = TaskContext.get().getLocalProperty('streaming.sql.batchId')
+                if epoch_id:
+                    epoch_id = int(epoch_id)
+                else:
+                    raise Exception("Could not get batch id from TaskContext")
+
+                # Check if the data should be processed
+                should_process = True
+                if open_exists:
+                    should_process = f.open(partition_id, epoch_id)
+
+                error = None
+
+                try:
+                    if should_process:
+                        for x in iterator:
+                            f.process(x)
+                except Exception as ex:
+                    error = ex
+                finally:
+                    if close_exists:
+                        f.close(error)
+                    if error:
+                        raise error
+
+                return iter([])
+
+            func = func_with_open_process_close
+
+        serializer = AutoBatchedSerializer(PickleSerializer())
+        wrapped_func = _wrap_function(self._spark._sc, func, serializer, serializer)
+        jForeachWriter = \
+            self._spark._sc._jvm.org.apache.spark.sql.execution.python.PythonForeachWriter(
+                wrapped_func, self._df._jdf.schema())
+        self._jwrite.foreach(jForeachWriter)
+        return self
+
     @ignore_unicode_prefix
     @since(2.0)
     def start(self, path=None, format=None, outputMode=None, partitionBy=None, queryName=None,
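
The docstring above notes that, in micro-batch mode, the (partition_id, epoch_id) pair uniquely identifies a block of data and can be used to deduplicate output when failed epochs are replayed. A minimal sketch of that pattern (not part of this commit) follows; `sdf` is assumed to be a streaming DataFrame as in the doctests, and the `commits` set is a hypothetical stand-in for a durable, sink-side commit log (a plain Python set is process-local and would not actually be shared across executors).

commits = set()  # hypothetical commit log; a real sink would persist this transactionally

class IdempotentWriter(object):
    def open(self, partition_id, epoch_id):
        self.key = (partition_id, epoch_id)
        self.buffer = []
        # Returning False tells Spark to skip process() for this partition/epoch.
        return self.key not in commits

    def process(self, row):
        self.buffer.append(row)

    def close(self, error):
        if error is None:
            # Write self.buffer keyed by (partition_id, epoch_id), then record the commit,
            # so a replayed epoch is skipped or overwritten instead of duplicated.
            commits.add(self.key)

query = sdf.writeStream.foreach(IdempotentWriter()).start()

In continuous mode the same (partition_id, epoch_id) guarantee does not hold, so this pattern should not be relied on for deduplication there.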
