Commit ff85b64

tdas authored and otterc committed
[SPARK-24396][SS][PYSPARK] Add Structured Streaming ForeachWriter for python
This PR adds `foreach` for streaming queries in Python. Users will be able to specify their processing logic in two different ways:

- As a function that takes a row as input.
- As an object that has `open`, `process`, and `close` methods.

See the Python docs in this PR for more details.

Added Java and Python unit tests.

Author: Tathagata Das <[email protected]>

Closes apache#21477 from tdas/SPARK-24396.

Ref: LIHADOOP-48531
1 parent 0294f6e commit ff85b64
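
The two styles described in the commit message might be used roughly as follows. This is a minimal usage sketch, not part of the commit: it assumes an active SparkSession named `spark` and uses the built-in rate source only to obtain a streaming DataFrame `sdf`.

    # Minimal sketch of the two usage styles added by this commit
    # (assumes an active SparkSession named `spark`).
    sdf = spark.readStream.format("rate").load()

    # Style 1: a plain function, called once per row.
    def process_row(row):
        print(row)  # replace with a write to any external system

    query1 = sdf.writeStream.foreach(process_row).start()

    # Style 2: an object with open/process/close methods, which also receives
    # partition_id and epoch_id for connection setup and deduplication.
    class SimpleWriter:
        def open(self, partition_id, epoch_id):
            return True  # True means "process the rows of this partition/epoch"

        def process(self, row):
            print(row)

        def close(self, error):
            pass  # release resources; `error` is the exception seen, if any

    query2 = sdf.writeStream.foreach(SimpleWriter()).start()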

8 files changed: +923, -59 lines changed

python/pyspark/sql/streaming.py

Lines changed: 162 additions & 0 deletions
@@ -837,6 +837,168 @@ def trigger(self, processingTime=None, once=None, continuous=None):
         self._jwrite = self._jwrite.trigger(jTrigger)
         return self

+    @since(2.4)
+    def foreach(self, f):
+        """
+        Sets the output of the streaming query to be processed using the provided writer ``f``.
+        This is often used to write the output of a streaming query to arbitrary storage systems.
+        The processing logic can be specified in two ways.
+
+        #. A **function** that takes a row as input.
+            This is a simple way to express your processing logic. Note that this does
+            not allow you to deduplicate generated data when failures cause reprocessing of
+            some input data. That would require you to specify the processing logic in the next
+            way.
+
+        #. An **object** with a ``process`` method and optional ``open`` and ``close`` methods.
+            The object can have the following methods.
+
+            * ``open(partition_id, epoch_id)``: *Optional* method that initializes the processing
+                (for example, open a connection, start a transaction, etc). Additionally, you can
+                use the `partition_id` and `epoch_id` to deduplicate regenerated data
+                (discussed later).
+
+            * ``process(row)``: *Non-optional* method that processes each :class:`Row`.
+
+            * ``close(error)``: *Optional* method that finalizes and cleans up (for example,
+                close connection, commit transaction, etc.) after all rows have been processed.
+
+            The object will be used by Spark in the following way.
+
+            * A single copy of this object is responsible for all the data generated by a
+                single task in a query. In other words, one instance is responsible for
+                processing one partition of the data generated in a distributed manner.
+
+            * This object must be serializable because each task will get a fresh
+                serialized-deserialized copy of the provided object. Hence, it is strongly
+                recommended that any initialization for writing data (e.g. opening a
+                connection or starting a transaction) is done after the `open(...)`
+                method has been called, which signifies that the task is ready to generate data.
+
+            * The lifecycle of the methods is as follows.
+
+                For each partition with ``partition_id``:
+
+                ... For each batch/epoch of streaming data with ``epoch_id``:
+
+                ....... Method ``open(partitionId, epochId)`` is called.
+
+                ....... If ``open(...)`` returns true, for each row in the partition and
+                        batch/epoch, method ``process(row)`` is called.
+
+                ....... Method ``close(errorOrNull)`` is called with error (if any) seen while
+                        processing rows.
+
+        Important points to note:
+
+        * The `partitionId` and `epochId` can be used to deduplicate generated data when
+            failures cause reprocessing of some input data. This depends on the execution
+            mode of the query. If the streaming query is being executed in the micro-batch
+            mode, then every partition represented by a unique tuple (partition_id, epoch_id)
+            is guaranteed to have the same data. Hence, (partition_id, epoch_id) can be used
+            to deduplicate and/or transactionally commit data and achieve exactly-once
+            guarantees. However, if the streaming query is being executed in the continuous
+            mode, then this guarantee does not hold and therefore should not be used for
+            deduplication.
+
+        * The ``close()`` method (if it exists) will be called if the `open()` method exists and
+            returns successfully (irrespective of the return value), except if the Python
+            process crashes in the middle.
+
+        .. note:: Evolving.
+
+        >>> # Print every row using a function
+        >>> def print_row(row):
+        ...     print(row)
+        ...
+        >>> writer = sdf.writeStream.foreach(print_row)
+        >>> # Print every row using an object with a process() method
+        >>> class RowPrinter:
+        ...     def open(self, partition_id, epoch_id):
+        ...         print("Opened %d, %d" % (partition_id, epoch_id))
+        ...         return True
+        ...     def process(self, row):
+        ...         print(row)
+        ...     def close(self, error):
+        ...         print("Closed with error: %s" % str(error))
+        ...
+        >>> writer = sdf.writeStream.foreach(RowPrinter())
+        """
+
+        from pyspark.rdd import _wrap_function
+        from pyspark.serializers import PickleSerializer, AutoBatchedSerializer
+        from pyspark.taskcontext import TaskContext
+
+        if callable(f):
+            # The provided object is a callable function that is supposed to be called on each row.
+            # Construct a function that takes an iterator and calls the provided function on each
+            # row.
+            def func_without_process(_, iterator):
+                for x in iterator:
+                    f(x)
+                return iter([])
+
+            func = func_without_process
+
+        else:
+            # The provided object is not a callable function. Then it is expected to have a
+            # 'process(row)' method, and optional 'open(partition_id, epoch_id)' and
+            # 'close(error)' methods.
+
+            if not hasattr(f, 'process'):
+                raise Exception("Provided object does not have a 'process' method")
+
+            if not callable(getattr(f, 'process')):
+                raise Exception("Attribute 'process' in provided object is not callable")
+
+            def doesMethodExist(method_name):
+                exists = hasattr(f, method_name)
+                if exists and not callable(getattr(f, method_name)):
+                    raise Exception(
+                        "Attribute '%s' in provided object is not callable" % method_name)
+                return exists
+
+            open_exists = doesMethodExist('open')
+            close_exists = doesMethodExist('close')
+
+            def func_with_open_process_close(partition_id, iterator):
+                epoch_id = TaskContext.get().getLocalProperty('streaming.sql.batchId')
+                if epoch_id:
+                    epoch_id = int(epoch_id)
+                else:
+                    raise Exception("Could not get batch id from TaskContext")
+
+                # Check if the data should be processed
+                should_process = True
+                if open_exists:
+                    should_process = f.open(partition_id, epoch_id)
+
+                error = None
+
+                try:
+                    if should_process:
+                        for x in iterator:
+                            f.process(x)
+                except Exception as ex:
+                    error = ex
+                finally:
+                    if close_exists:
+                        f.close(error)
+                    if error:
+                        raise error
+
+                return iter([])
+
+            func = func_with_open_process_close
+
+        serializer = AutoBatchedSerializer(PickleSerializer())
+        wrapped_func = _wrap_function(self._spark._sc, func, serializer, serializer)
+        jForeachWriter = \
+            self._spark._sc._jvm.org.apache.spark.sql.execution.python.PythonForeachWriter(
+                wrapped_func, self._df._jdf.schema())
+        self._jwrite.foreach(jForeachWriter)
+        return self
+
     @ignore_unicode_prefix
     @since(2.0)
     def start(self, path=None, format=None, outputMode=None, partitionBy=None, queryName=None,
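
As a follow-up to the deduplication note in the docstring above, a writer for a micro-batch query could key its output on (partition_id, epoch_id) to make retried epochs idempotent. The sketch below assumes a hypothetical transactional key-value client named `kv_store` and a streaming DataFrame `sdf`; neither is part of this commit.

    # Hypothetical idempotent sink keyed on (partition_id, epoch_id); `kv_store`
    # stands in for any transactional storage client and is not part of this commit.
    class IdempotentWriter:
        def open(self, partition_id, epoch_id):
            self.key = (partition_id, epoch_id)
            self.buffer = []
            # Returning False skips process() for an already-committed partition/epoch.
            return not kv_store.already_committed(self.key)

        def process(self, row):
            self.buffer.append(row.asDict())

        def close(self, error):
            if error is None:
                # Commit all rows for this (partition, epoch) atomically, so a
                # reprocessed epoch replaces rather than duplicates earlier output.
                kv_store.commit(self.key, self.buffer)

    query = sdf.writeStream.foreach(IdempotentWriter()).start()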
