@@ -837,6 +837,168 @@ def trigger(self, processingTime=None, once=None, continuous=None):
         self._jwrite = self._jwrite.trigger(jTrigger)
         return self

+    @since(2.4)
+    def foreach(self, f):
+        """
+        Sets the output of the streaming query to be processed using the provided writer ``f``.
+        This is often used to write the output of a streaming query to arbitrary storage systems.
+        The processing logic can be specified in two ways.
+
+        #. A **function** that takes a row as input.
+            This is a simple way to express your processing logic. Note that this does
+            not allow you to deduplicate generated data when failures cause reprocessing of
+            some input data. That would require you to specify the processing logic in the next
+            way.
+
+        #. An **object** with a ``process`` method and optional ``open`` and ``close`` methods.
+            The object can have the following methods.
+
+            * ``open(partition_id, epoch_id)``: *Optional* method that initializes the processing
+              (for example, open a connection, start a transaction, etc.). Additionally, you can
+              use the `partition_id` and `epoch_id` to deduplicate regenerated data
+              (discussed later).
+
+            * ``process(row)``: *Non-optional* method that processes each :class:`Row`.
+
+            * ``close(error)``: *Optional* method that finalizes and cleans up (for example,
+              close connection, commit transaction, etc.) after all rows have been processed.
+
+            The object will be used by Spark in the following way.
+
+            * A single copy of this object is responsible for all the data generated by a
+              single task in a query. In other words, one instance is responsible for
+              processing one partition of the data generated in a distributed manner.
+
+            * This object must be serializable because each task will get a fresh
+              serialized-deserialized copy of the provided object. Hence, it is strongly
+              recommended that any initialization for writing data (e.g. opening a
+              connection or starting a transaction) is done after the `open(...)`
+              method has been called, which signifies that the task is ready to generate data.
+
+            * The lifecycle of the methods is as follows.
+
+              For each partition with ``partition_id``:
+
+              ... For each batch/epoch of streaming data with ``epoch_id``:
+
+              ....... Method ``open(partition_id, epoch_id)`` is called.
+
+              ....... If ``open(...)`` returns ``True``, for each row in the partition and
+                      batch/epoch, method ``process(row)`` is called.
+
+              ....... Method ``close(error)`` is called with the error (if any) seen while
+                      processing rows.
+
+            Important points to note:
+
+            * The ``partition_id`` and ``epoch_id`` can be used to deduplicate generated data when
+              failures cause reprocessing of some input data. This depends on the execution
+              mode of the query. If the streaming query is being executed in the micro-batch
+              mode, then every partition represented by a unique tuple (partition_id, epoch_id)
+              is guaranteed to have the same data. Hence, (partition_id, epoch_id) can be used
+              to deduplicate and/or transactionally commit data and achieve exactly-once
+              guarantees. However, if the streaming query is being executed in the continuous
+              mode, then this guarantee does not hold and therefore should not be used for
+              deduplication.
+
+            * The ``close()`` method (if it exists) will be called if the ``open()`` method exists
+              and returns successfully (irrespective of its return value), unless the Python
+              process crashes in the middle.
+
+        .. note:: Evolving.
+
+        >>> # Print every row using a function
+        >>> def print_row(row):
+        ...     print(row)
+        ...
+        >>> writer = sdf.writeStream.foreach(print_row)
+        >>> # Print every row using an object with a process() method
+        >>> class RowPrinter:
+        ...     def open(self, partition_id, epoch_id):
+        ...         print("Opened %d, %d" % (partition_id, epoch_id))
+        ...         return True
+        ...     def process(self, row):
+        ...         print(row)
+        ...     def close(self, error):
+        ...         print("Closed with error: %s" % str(error))
+        ...
+        >>> writer = sdf.writeStream.foreach(RowPrinter())
+        """
+
+        from pyspark.rdd import _wrap_function
+        from pyspark.serializers import PickleSerializer, AutoBatchedSerializer
+        from pyspark.taskcontext import TaskContext
+
+        if callable(f):
+            # The provided object is a callable function that is supposed to be called on each row.
+            # Construct a function that takes an iterator and calls the provided function on each
+            # row.
+            def func_without_process(_, iterator):
+                for x in iterator:
+                    f(x)
+                return iter([])
+
+            func = func_without_process
+
+        else:
+            # The provided object is not a callable function. Then it is expected to have a
+            # 'process(row)' method, and optional 'open(partition_id, epoch_id)' and
+            # 'close(error)' methods.
+
+            if not hasattr(f, 'process'):
+                raise Exception("Provided object does not have a 'process' method")
+
+            if not callable(getattr(f, 'process')):
+                raise Exception("Attribute 'process' in provided object is not callable")
+
+            def doesMethodExist(method_name):
+                exists = hasattr(f, method_name)
+                if exists and not callable(getattr(f, method_name)):
+                    raise Exception(
+                        "Attribute '%s' in provided object is not callable" % method_name)
+                return exists
+
+            open_exists = doesMethodExist('open')
+            close_exists = doesMethodExist('close')
+
+            def func_with_open_process_close(partition_id, iterator):
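+                # The id of the current micro-batch/epoch is exposed to this task as the
+                # local property 'streaming.sql.batchId'; it is looked up here so it can
+                # be passed to the user's open() method.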
+                epoch_id = TaskContext.get().getLocalProperty('streaming.sql.batchId')
+                if epoch_id:
+                    epoch_id = int(epoch_id)
+                else:
+                    raise Exception("Could not get batch id from TaskContext")
+
+                # Check if the data should be processed
+                should_process = True
+                if open_exists:
+                    should_process = f.open(partition_id, epoch_id)
+
+                error = None
+
+                try:
+                    if should_process:
+                        for x in iterator:
+                            f.process(x)
+                except Exception as ex:
+                    error = ex
+                finally:
+                    if close_exists:
+                        f.close(error)
+                    if error:
+                        raise error
+
+                return iter([])
+
+            func = func_with_open_process_close
+
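+        # Pickle the per-partition function and wrap it so the JVM-side PythonForeachWriter
+        # below can invoke it against each partition of every micro-batch/epoch.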
+        serializer = AutoBatchedSerializer(PickleSerializer())
+        wrapped_func = _wrap_function(self._spark._sc, func, serializer, serializer)
+        jForeachWriter = \
+            self._spark._sc._jvm.org.apache.spark.sql.execution.python.PythonForeachWriter(
+                wrapped_func, self._df._jdf.schema())
+        self._jwrite.foreach(jForeachWriter)
+        return self
+
     @ignore_unicode_prefix
     @since(2.0)
     def start(self, path=None, format=None, outputMode=None, partitionBy=None, queryName=None,
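
The doctest examples in the docstring only build a writer; nothing runs until ``start()`` is called on it. Below is a minimal end-to-end sketch of the function form, not part of the patch above: it assumes a local SparkSession named ``spark``, and the built-in ``rate`` source and the ten-second runtime are arbitrary illustrative choices.

from pyspark.sql import SparkSession

spark = SparkSession.builder.master("local[2]").appName("foreach-sketch").getOrCreate()

# A tiny unbounded stream: one row per second with `timestamp` and `value` columns.
sdf = spark.readStream.format("rate").option("rowsPerSecond", 1).load()

def print_row(row):
    # Called on the executors, once for every row of every micro-batch.
    print(row)

query = sdf.writeStream.foreach(print_row).start()
query.awaitTermination(10)  # let the query run for about ten seconds
query.stop()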
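
The docstring's point about deduplicating with (partition_id, epoch_id) is easier to see with a concrete writer. This second sketch, also not part of the patch, reuses ``sdf`` from the previous snippet: each partition/epoch slice is written to its own file, and a slice whose file already exists is skipped, which is the idempotent micro-batch pattern described above. The ``/tmp/foreach-output`` path is an arbitrary choice and only makes sense with a local master, where the driver and executors share a filesystem.

import os

class IdempotentFileWriter(object):
    """Writes each (partition_id, epoch_id) slice to its own file, exactly once."""

    def open(self, partition_id, epoch_id):
        self.path = "/tmp/foreach-output/part-%d-epoch-%d.txt" % (partition_id, epoch_id)
        self.rows = None
        if os.path.exists(self.path):
            # This slice was committed by an earlier attempt. In micro-batch mode a
            # replayed (partition_id, epoch_id) carries the same data, so skip it.
            return False
        self.rows = []
        return True

    def process(self, row):
        self.rows.append(str(row))

    def close(self, error):
        # close() is also called when open() returned False, so guard on self.rows.
        if error is None and self.rows is not None:
            try:
                os.makedirs(os.path.dirname(self.path))
            except OSError:
                pass  # directory already created by another task
            # Write to a temporary file and rename, so the final file appears atomically.
            tmp_path = self.path + ".tmp"
            with open(tmp_path, "w") as out:
                out.write("\n".join(self.rows))
            os.rename(tmp_path, self.path)

query = sdf.writeStream.foreach(IdempotentFileWriter()).start()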