@@ -854,6 +854,168 @@ def trigger(self, processingTime=None, once=None, continuous=None):
         self._jwrite = self._jwrite.trigger(jTrigger)
         return self
 
+    @since(2.4)
+    def foreach(self, f):
+        """
+        Sets the output of the streaming query to be processed using the provided writer ``f``.
+        This is often used to write the output of a streaming query to arbitrary storage systems.
+        The processing logic can be specified in two ways.
+
+        #. A **function** that takes a row as input.
+            This is a simple way to express your processing logic. Note that this does
+            not allow you to deduplicate generated data when failures cause reprocessing of
+            some input data. That would require you to specify the processing logic in the next
+            way.
+
+        #. An **object** with a ``process`` method and optional ``open`` and ``close`` methods.
+            The object can have the following methods.
+
+            * ``open(partition_id, epoch_id)``: *Optional* method that initializes the processing
+                (for example, open a connection, start a transaction, etc.). Additionally, you can
+                use the `partition_id` and `epoch_id` to deduplicate regenerated data
+                (discussed later).
+
+            * ``process(row)``: *Non-optional* method that processes each :class:`Row`.
+
+            * ``close(error)``: *Optional* method that finalizes and cleans up (for example,
+                close connection, commit transaction, etc.) after all rows have been processed.
+
+            The object will be used by Spark in the following way.
+
+            * A single copy of this object is responsible for all the data generated by a
+                single task in a query. In other words, one instance is responsible for
+                processing one partition of the data generated in a distributed manner.
+
+            * This object must be serializable because each task will get a fresh
+                serialized-deserialized copy of the provided object. Hence, it is strongly
+                recommended that any initialization for writing data (e.g. opening a
+                connection or starting a transaction) is done after the `open(...)`
+                method has been called, which signifies that the task is ready to generate data.
+
+            * The lifecycle of the methods is as follows.
+
+                For each partition with ``partition_id``:
+
+                ... For each batch/epoch of streaming data with ``epoch_id``:
+
+                ....... Method ``open(partition_id, epoch_id)`` is called.
+
+                ....... If ``open(...)`` returns ``True``, for each row in the partition and
+                        batch/epoch, method ``process(row)`` is called.
+
+                ....... Method ``close(error)`` is called with the error (if any) seen while
+                        processing rows.
+
+            Important points to note:
+
+            * The `partition_id` and `epoch_id` can be used to deduplicate generated data when
+                failures cause reprocessing of some input data. This depends on the execution
+                mode of the query. If the streaming query is being executed in micro-batch
+                mode, then every partition represented by a unique tuple (partition_id, epoch_id)
+                is guaranteed to have the same data. Hence, (partition_id, epoch_id) can be used
+                to deduplicate and/or transactionally commit data and achieve exactly-once
+                guarantees. However, if the streaming query is being executed in continuous
+                mode, then this guarantee does not hold and therefore should not be used for
+                deduplication.
+
+            * The ``close()`` method (if it exists) will be called if the ``open()`` method exists
+                and returns successfully (irrespective of the return value), except if Python
+                crashes in the middle.
+
+        .. note:: Evolving.
+
+        >>> # Print every row using a function
+        >>> def print_row(row):
+        ...     print(row)
+        ...
+        >>> writer = sdf.writeStream.foreach(print_row)
+        >>> # Print every row using an object with a process() method
+        >>> class RowPrinter:
+        ...     def open(self, partition_id, epoch_id):
+        ...         print("Opened %d, %d" % (partition_id, epoch_id))
+        ...         return True
+        ...     def process(self, row):
+        ...         print(row)
+        ...     def close(self, error):
+        ...         print("Closed with error: %s" % str(error))
+        ...
+        >>> writer = sdf.writeStream.foreach(RowPrinter())
+        """
+
+        from pyspark.rdd import _wrap_function
+        from pyspark.serializers import PickleSerializer, AutoBatchedSerializer
+        from pyspark.taskcontext import TaskContext
+
+        if callable(f):
+            # The provided object is a callable function that is supposed to be called on each row.
+            # Construct a function that takes an iterator and calls the provided function on each
+            # row.
+            def func_without_process(_, iterator):
+                for x in iterator:
+                    f(x)
+                return iter([])
+
+            func = func_without_process
+
+        else:
+            # The provided object is not a callable function. Then it is expected to have a
+            # 'process(row)' method, and optional 'open(partition_id, epoch_id)' and
+            # 'close(error)' methods.
+
+            if not hasattr(f, 'process'):
+                raise Exception("Provided object does not have a 'process' method")
+
+            if not callable(getattr(f, 'process')):
+                raise Exception("Attribute 'process' in provided object is not callable")
+
+            def doesMethodExist(method_name):
+                exists = hasattr(f, method_name)
+                if exists and not callable(getattr(f, method_name)):
+                    raise Exception(
+                        "Attribute '%s' in provided object is not callable" % method_name)
+                return exists
+
+            open_exists = doesMethodExist('open')
+            close_exists = doesMethodExist('close')
+
+            def func_with_open_process_close(partition_id, iterator):
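+                # The streaming engine publishes the id of the current batch/epoch to tasks
+                # as the 'streaming.sql.batchId' local property; together with the partition
+                # id it enables the deduplication pattern described in the docstring above.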
+                epoch_id = TaskContext.get().getLocalProperty('streaming.sql.batchId')
+                if epoch_id:
+                    epoch_id = int(epoch_id)
+                else:
+                    raise Exception("Could not get batch id from TaskContext")
+
+                # Check if the data should be processed
+                should_process = True
+                if open_exists:
+                    should_process = f.open(partition_id, epoch_id)
+
+                error = None
+
+                try:
+                    if should_process:
+                        for x in iterator:
+                            f.process(x)
+                except Exception as ex:
+                    error = ex
+                finally:
+                    if close_exists:
+                        f.close(error)
+                    if error:
+                        raise error
+
+                return iter([])
+
+            func = func_with_open_process_close
+
+        serializer = AutoBatchedSerializer(PickleSerializer())
+        wrapped_func = _wrap_function(self._spark._sc, func, serializer, serializer)
+        jForeachWriter = \
+            self._spark._sc._jvm.org.apache.spark.sql.execution.python.PythonForeachWriter(
+                wrapped_func, self._df._jdf.schema())
+        self._jwrite.foreach(jForeachWriter)
+        return self
+
     @ignore_unicode_prefix
     @since(2.0)
     def start(self, path=None, format=None, outputMode=None, partitionBy=None, queryName=None,
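
The docstring above explains that, in micro-batch mode, the (partition_id, epoch_id) pair passed to open() can be used to make writes idempotent, but the patch itself does not show that pattern in code. The following is a minimal sketch of such a writer, not part of this commit: the IdempotentFileWriter class, its output directory, and the file-per-(partition, epoch) layout are all hypothetical, and a real sink would more likely be a transactional database or key-value store keyed on (partition_id, epoch_id).

import os
import tempfile


class IdempotentFileWriter(object):
    """Hypothetical foreach writer that skips (partition, epoch) pairs it already committed."""

    def __init__(self, output_dir):
        # Assumes output_dir already exists on a filesystem visible to every executor.
        self.output_dir = output_dir

    def open(self, partition_id, epoch_id):
        self.path = os.path.join(
            self.output_dir, "part-%d-epoch-%d" % (partition_id, epoch_id))
        self.rows = []
        # The output file is created only by a successful commit in close(), so its
        # presence means this (partition_id, epoch_id) was already written; returning
        # False tells Spark to skip process() for the reprocessed rows.
        self.already_committed = os.path.exists(self.path)
        return not self.already_committed

    def process(self, row):
        # Buffer rows so the commit in close() is all-or-nothing.
        self.rows.append(str(row))

    def close(self, error):
        if error is None and not self.already_committed:
            # Write to a temporary file first and rename it into place, so a failure
            # mid-write never leaves a partially written "committed" file behind.
            fd, tmp_path = tempfile.mkstemp(dir=self.output_dir)
            with os.fdopen(fd, "w") as out:
                out.write("\n".join(self.rows))
            os.rename(tmp_path, self.path)


# Hypothetical usage, with sdf being a streaming DataFrame as in the docstring examples:
# query = sdf.writeStream.foreach(IdempotentFileWriter("/tmp/foreach-sink")).start()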