From 8bfdb6897c6c2b1af2ae158c7355621576f07142 Mon Sep 17 00:00:00 2001 From: edorigatti Date: Fri, 1 Jun 2018 13:18:21 +0200 Subject: [PATCH 01/10] Revert "[SPARK-23754][PYTHON] Re-raising StopIteration in client code" This reverts commit 0ebb0c0d4dd3e192464dc5e0e6f01efa55b945ed. --- python/pyspark/rdd.py | 18 +++---------- python/pyspark/shuffle.py | 7 +++-- python/pyspark/sql/tests.py | 16 ----------- python/pyspark/sql/udf.py | 14 ++-------- python/pyspark/tests.py | 53 ------------------------------------- python/pyspark/util.py | 28 +++----------------- 6 files changed, 11 insertions(+), 125 deletions(-) diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 14d9128502ab0..d5a237a5b2855 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -53,7 +53,6 @@ from pyspark.shuffle import Aggregator, ExternalMerger, \ get_used_memory, ExternalSorter, ExternalGroupBy from pyspark.traceback_utils import SCCallSiteSync -from pyspark.util import fail_on_stopiteration __all__ = ["RDD"] @@ -340,7 +339,7 @@ def map(self, f, preservesPartitioning=False): [('a', 1), ('b', 1), ('c', 1)] """ def func(_, iterator): - return map(fail_on_stopiteration(f), iterator) + return map(f, iterator) return self.mapPartitionsWithIndex(func, preservesPartitioning) def flatMap(self, f, preservesPartitioning=False): @@ -355,7 +354,7 @@ def flatMap(self, f, preservesPartitioning=False): [(2, 2), (2, 2), (3, 3), (3, 3), (4, 4), (4, 4)] """ def func(s, iterator): - return chain.from_iterable(map(fail_on_stopiteration(f), iterator)) + return chain.from_iterable(map(f, iterator)) return self.mapPartitionsWithIndex(func, preservesPartitioning) def mapPartitions(self, f, preservesPartitioning=False): @@ -418,7 +417,7 @@ def filter(self, f): [2, 4] """ def func(iterator): - return filter(fail_on_stopiteration(f), iterator) + return filter(f, iterator) return self.mapPartitions(func, True) def distinct(self, numPartitions=None): @@ -799,8 +798,6 @@ def foreach(self, f): >>> def f(x): print(x) >>> sc.parallelize([1, 2, 3, 4, 5]).foreach(f) """ - f = fail_on_stopiteration(f) - def processPartition(iterator): for x in iterator: f(x) @@ -850,8 +847,6 @@ def reduce(self, f): ... 
ValueError: Can not reduce() empty RDD """ - f = fail_on_stopiteration(f) - def func(iterator): iterator = iter(iterator) try: @@ -923,8 +918,6 @@ def fold(self, zeroValue, op): >>> sc.parallelize([1, 2, 3, 4, 5]).fold(0, add) 15 """ - op = fail_on_stopiteration(op) - def func(iterator): acc = zeroValue for obj in iterator: @@ -957,9 +950,6 @@ def aggregate(self, zeroValue, seqOp, combOp): >>> sc.parallelize([]).aggregate((0, 0), seqOp, combOp) (0, 0) """ - seqOp = fail_on_stopiteration(seqOp) - combOp = fail_on_stopiteration(combOp) - def func(iterator): acc = zeroValue for obj in iterator: @@ -1653,8 +1643,6 @@ def reduceByKeyLocally(self, func): >>> sorted(rdd.reduceByKeyLocally(add).items()) [('a', 2), ('b', 1)] """ - func = fail_on_stopiteration(func) - def reducePartition(iterator): m = {} for k, v in iterator: diff --git a/python/pyspark/shuffle.py b/python/pyspark/shuffle.py index bd0ac0039ffe1..02c773302e9da 100644 --- a/python/pyspark/shuffle.py +++ b/python/pyspark/shuffle.py @@ -28,7 +28,6 @@ import pyspark.heapq3 as heapq from pyspark.serializers import BatchedSerializer, PickleSerializer, FlattenedValuesSerializer, \ CompressedSerializer, AutoBatchedSerializer -from pyspark.util import fail_on_stopiteration try: @@ -95,9 +94,9 @@ class Aggregator(object): """ def __init__(self, createCombiner, mergeValue, mergeCombiners): - self.createCombiner = fail_on_stopiteration(createCombiner) - self.mergeValue = fail_on_stopiteration(mergeValue) - self.mergeCombiners = fail_on_stopiteration(mergeCombiners) + self.createCombiner = createCombiner + self.mergeValue = mergeValue + self.mergeCombiners = mergeCombiners class SimpleAggregator(Aggregator): diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index a2450932e303d..c7bd8f01b907f 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -900,22 +900,6 @@ def __call__(self, x): self.assertEqual(f, f_.func) self.assertEqual(return_type, f_.returnType) - def test_stopiteration_in_udf(self): - # test for SPARK-23754 - from pyspark.sql.functions import udf - from py4j.protocol import Py4JJavaError - - def foo(x): - raise StopIteration() - - with self.assertRaises(Py4JJavaError) as cm: - self.spark.range(0, 1000).withColumn('v', udf(foo)('id')).show() - - self.assertIn( - "Caught StopIteration thrown from user's code; failing the task", - cm.exception.java_exception.toString() - ) - def test_validate_column_types(self): from pyspark.sql.functions import udf, to_json from pyspark.sql.column import _to_java_column diff --git a/python/pyspark/sql/udf.py b/python/pyspark/sql/udf.py index c8fb49d7c2b65..9dbe49b831cef 100644 --- a/python/pyspark/sql/udf.py +++ b/python/pyspark/sql/udf.py @@ -25,7 +25,7 @@ from pyspark.sql.column import Column, _to_java_column, _to_seq from pyspark.sql.types import StringType, DataType, StructType, _parse_datatype_string,\ to_arrow_type, to_arrow_schema -from pyspark.util import _get_argspec, fail_on_stopiteration +from pyspark.util import _get_argspec __all__ = ["UDFRegistration"] @@ -157,17 +157,7 @@ def _create_judf(self): spark = SparkSession.builder.getOrCreate() sc = spark.sparkContext - func = fail_on_stopiteration(self.func) - - # for pandas UDFs the worker needs to know if the function takes - # one or two arguments, but the signature is lost when wrapping with - # fail_on_stopiteration, so we store it here - if self.evalType in (PythonEvalType.SQL_SCALAR_PANDAS_UDF, - PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF, - PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF): - 
func._argspec = _get_argspec(self.func) - - wrapped_func = _wrap_function(sc, func, self.returnType) + wrapped_func = _wrap_function(sc, self.func, self.returnType) jdt = spark._jsparkSession.parseDataType(self.returnType.json()) judf = sc._jvm.org.apache.spark.sql.execution.python.UserDefinedPythonFunction( self._name, wrapped_func, jdt, self.evalType, self.deterministic) diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index 30723b8e15b36..baeb5ccabb58f 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -161,37 +161,6 @@ def gen_gs(N, step=1): self.assertEqual(k, len(vs)) self.assertEqual(list(range(k)), list(vs)) - def test_stopiteration_is_raised(self): - - def stopit(*args, **kwargs): - raise StopIteration() - - def legit_create_combiner(x): - return [x] - - def legit_merge_value(x, y): - return x.append(y) or x - - def legit_merge_combiners(x, y): - return x.extend(y) or x - - data = [(x % 2, x) for x in range(100)] - - # wrong create combiner - m = ExternalMerger(Aggregator(stopit, legit_merge_value, legit_merge_combiners), 20) - with self.assertRaises((Py4JJavaError, RuntimeError)) as cm: - m.mergeValues(data) - - # wrong merge value - m = ExternalMerger(Aggregator(legit_create_combiner, stopit, legit_merge_combiners), 20) - with self.assertRaises((Py4JJavaError, RuntimeError)) as cm: - m.mergeValues(data) - - # wrong merge combiners - m = ExternalMerger(Aggregator(legit_create_combiner, legit_merge_value, stopit), 20) - with self.assertRaises((Py4JJavaError, RuntimeError)) as cm: - m.mergeCombiners(map(lambda x_y1: (x_y1[0], [x_y1[1]]), data)) - class SorterTests(unittest.TestCase): def test_in_memory_sort(self): @@ -1291,28 +1260,6 @@ def test_pipe_unicode(self): result = rdd.pipe('cat').collect() self.assertEqual(data, result) - def test_stopiteration_in_client_code(self): - - def stopit(*x): - raise StopIteration() - - seq_rdd = self.sc.parallelize(range(10)) - keyed_rdd = self.sc.parallelize((x % 2, x) for x in range(10)) - - self.assertRaises(Py4JJavaError, seq_rdd.map(stopit).collect) - self.assertRaises(Py4JJavaError, seq_rdd.filter(stopit).collect) - self.assertRaises(Py4JJavaError, seq_rdd.cartesian(seq_rdd).flatMap(stopit).collect) - self.assertRaises(Py4JJavaError, seq_rdd.foreach, stopit) - self.assertRaises(Py4JJavaError, keyed_rdd.reduceByKeyLocally, stopit) - self.assertRaises(Py4JJavaError, seq_rdd.reduce, stopit) - self.assertRaises(Py4JJavaError, seq_rdd.fold, 0, stopit) - - # the exception raised is non-deterministic - self.assertRaises((Py4JJavaError, RuntimeError), - seq_rdd.aggregate, 0, stopit, lambda *x: 1) - self.assertRaises((Py4JJavaError, RuntimeError), - seq_rdd.aggregate, 0, lambda *x: 1, stopit) - class ProfilerTests(PySparkTestCase): diff --git a/python/pyspark/util.py b/python/pyspark/util.py index e95a9b523393f..59cc2a6329350 100644 --- a/python/pyspark/util.py +++ b/python/pyspark/util.py @@ -53,16 +53,11 @@ def _get_argspec(f): """ Get argspec of a function. Supports both Python 2 and Python 3. """ - - if hasattr(f, '_argspec'): - # only used for pandas UDF: they wrap the user function, losing its signature - # workers need this signature, so UDF saves it here - argspec = f._argspec - elif sys.version_info[0] < 3: + # `getargspec` is deprecated since python3.0 (incompatible with function annotations). + # See SPARK-23569. + if sys.version_info[0] < 3: argspec = inspect.getargspec(f) else: - # `getargspec` is deprecated since python3.0 (incompatible with function annotations). - # See SPARK-23569. 
argspec = inspect.getfullargspec(f) return argspec @@ -94,23 +89,6 @@ def majorMinorVersion(sparkVersion): " version numbers.") -def fail_on_stopiteration(f): - """ - Wraps the input function to fail on 'StopIteration' by raising a 'RuntimeError' - prevents silent loss of data when 'f' is used in a for loop - """ - def wrapper(*args, **kwargs): - try: - return f(*args, **kwargs) - except StopIteration as exc: - raise RuntimeError( - "Caught StopIteration thrown from user's code; failing the task", - exc - ) - - return wrapper - - if __name__ == "__main__": import doctest (failure_count, test_count) = doctest.testmod() From f755a48aca89e7a3514e719d3523eb3309fed488 Mon Sep 17 00:00:00 2001 From: edorigatti Date: Fri, 1 Jun 2018 15:29:10 +0200 Subject: [PATCH 02/10] re-raising StopIteration in user code --- python/pyspark/rdd.py | 15 ++++++++--- python/pyspark/shuffle.py | 7 ++--- python/pyspark/sql/tests.py | 41 ++++++++++++++++++++++++++++ python/pyspark/tests.py | 53 +++++++++++++++++++++++++++++++++++++ python/pyspark/util.py | 17 ++++++++++++ python/pyspark/worker.py | 18 ++++++++----- 6 files changed, 138 insertions(+), 13 deletions(-) diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index d5a237a5b2855..2834397b6a213 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -53,6 +53,7 @@ from pyspark.shuffle import Aggregator, ExternalMerger, \ get_used_memory, ExternalSorter, ExternalGroupBy from pyspark.traceback_utils import SCCallSiteSync +from pyspark.util import fail_on_stopiteration __all__ = ["RDD"] @@ -339,7 +340,7 @@ def map(self, f, preservesPartitioning=False): [('a', 1), ('b', 1), ('c', 1)] """ def func(_, iterator): - return map(f, iterator) + return map(fail_on_stopiteration(f), iterator) return self.mapPartitionsWithIndex(func, preservesPartitioning) def flatMap(self, f, preservesPartitioning=False): @@ -354,7 +355,7 @@ def flatMap(self, f, preservesPartitioning=False): [(2, 2), (2, 2), (3, 3), (3, 3), (4, 4), (4, 4)] """ def func(s, iterator): - return chain.from_iterable(map(f, iterator)) + return chain.from_iterable(map(fail_on_stopiteration(f), iterator)) return self.mapPartitionsWithIndex(func, preservesPartitioning) def mapPartitions(self, f, preservesPartitioning=False): @@ -417,7 +418,7 @@ def filter(self, f): [2, 4] """ def func(iterator): - return filter(f, iterator) + return filter(fail_on_stopiteration(f), iterator) return self.mapPartitions(func, True) def distinct(self, numPartitions=None): @@ -847,6 +848,8 @@ def reduce(self, f): ... 
ValueError: Can not reduce() empty RDD """ + f = fail_on_stopiteration(f) + def func(iterator): iterator = iter(iterator) try: @@ -918,6 +921,8 @@ def fold(self, zeroValue, op): >>> sc.parallelize([1, 2, 3, 4, 5]).fold(0, add) 15 """ + op = fail_on_stopiteration(op) + def func(iterator): acc = zeroValue for obj in iterator: @@ -950,6 +955,9 @@ def aggregate(self, zeroValue, seqOp, combOp): >>> sc.parallelize([]).aggregate((0, 0), seqOp, combOp) (0, 0) """ + seqOp = fail_on_stopiteration(seqOp) + combOp = fail_on_stopiteration(combOp) + def func(iterator): acc = zeroValue for obj in iterator: @@ -1628,6 +1636,7 @@ def reduceByKey(self, func, numPartitions=None, partitionFunc=portable_hash): >>> sorted(rdd.reduceByKey(add).collect()) [('a', 2), ('b', 1)] """ + func = fail_on_stopiteration(func) return self.combineByKey(lambda x: x, func, func, numPartitions, partitionFunc) def reduceByKeyLocally(self, func): diff --git a/python/pyspark/shuffle.py b/python/pyspark/shuffle.py index 02c773302e9da..bd0ac0039ffe1 100644 --- a/python/pyspark/shuffle.py +++ b/python/pyspark/shuffle.py @@ -28,6 +28,7 @@ import pyspark.heapq3 as heapq from pyspark.serializers import BatchedSerializer, PickleSerializer, FlattenedValuesSerializer, \ CompressedSerializer, AutoBatchedSerializer +from pyspark.util import fail_on_stopiteration try: @@ -94,9 +95,9 @@ class Aggregator(object): """ def __init__(self, createCombiner, mergeValue, mergeCombiners): - self.createCombiner = createCombiner - self.mergeValue = mergeValue - self.mergeCombiners = mergeCombiners + self.createCombiner = fail_on_stopiteration(createCombiner) + self.mergeValue = fail_on_stopiteration(mergeValue) + self.mergeCombiners = fail_on_stopiteration(mergeCombiners) class SimpleAggregator(Aggregator): diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index c7bd8f01b907f..5848d0531f4a5 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -900,6 +900,47 @@ def __call__(self, x): self.assertEqual(f, f_.func) self.assertEqual(return_type, f_.returnType) + def test_stopiteration_in_udf(self): + return + + from pyspark.sql.functions import udf, pandas_udf, PandasUDFType + from py4j.protocol import Py4JJavaError + + def do_test(action, *args, **kwargs): + exc_message = "Caught StopIteration thrown from user's code; failing the task" + with self.assertRaisesRegexp(Py4JJavaError, exc_message) as cm: + action(*args, **kwargs) + + def foo(x): + raise StopIteration() + + def foofoo(x, y): + raise StopIteration() + + df = self.spark.range(0, 100) + + # plain udf (test for SPARK-23754) + do_test(df.withColumn('v', udf(foo)('id')).show) + + # pandas scalar udf + do_test(df.withColumn( + 'v', pandas_udf(foo, 'double', PandasUDFType.SCALAR)('id') + ).show) + + # pandas grouped map + do_test(df.groupBy('id').apply( + pandas_udf(foo, df.schema, PandasUDFType.GROUPED_MAP) + ).show) + + do_test(df.groupBy('id').apply( + pandas_udf(foofoo, df.schema, PandasUDFType.GROUPED_MAP) + ).show) + + # pandas grouped agg + do_test(df.groupBy('id').agg( + pandas_udf(foo, 'double', PandasUDFType.GROUPED_AGG)('id') + ).show) + def test_validate_column_types(self): from pyspark.sql.functions import udf, to_json from pyspark.sql.column import _to_java_column diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index baeb5ccabb58f..25a912cf09cca 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -161,6 +161,37 @@ def gen_gs(N, step=1): self.assertEqual(k, len(vs)) self.assertEqual(list(range(k)), list(vs)) + def 
test_stopiteration_is_raised(self): + + def stopit(*args, **kwargs): + raise StopIteration() + + def legit_create_combiner(x): + return [x] + + def legit_merge_value(x, y): + return x.append(y) or x + + def legit_merge_combiners(x, y): + return x.extend(y) or x + + data = [(x % 2, x) for x in range(100)] + + # wrong create combiner + m = ExternalMerger(Aggregator(stopit, legit_merge_value, legit_merge_combiners), 20) + with self.assertRaises((Py4JJavaError, RuntimeError)) as cm: + m.mergeValues(data) + + # wrong merge value + m = ExternalMerger(Aggregator(legit_create_combiner, stopit, legit_merge_combiners), 20) + with self.assertRaises((Py4JJavaError, RuntimeError)) as cm: + m.mergeValues(data) + + # wrong merge combiners + m = ExternalMerger(Aggregator(legit_create_combiner, legit_merge_value, stopit), 20) + with self.assertRaises((Py4JJavaError, RuntimeError)) as cm: + m.mergeCombiners(map(lambda x_y1: (x_y1[0], [x_y1[1]]), data)) + class SorterTests(unittest.TestCase): def test_in_memory_sort(self): @@ -1260,6 +1291,28 @@ def test_pipe_unicode(self): result = rdd.pipe('cat').collect() self.assertEqual(data, result) + def test_stopiteration_in_user_code(self): + + def stopit(*x): + raise StopIteration() + + seq_rdd = self.sc.parallelize(range(10)) + keyed_rdd = self.sc.parallelize((x % 2, x) for x in range(10)) + + self.assertRaises(Py4JJavaError, seq_rdd.map(stopit).collect) + self.assertRaises(Py4JJavaError, seq_rdd.filter(stopit).collect) + self.assertRaises(Py4JJavaError, seq_rdd.cartesian(seq_rdd).flatMap(stopit).collect) + self.assertRaises(Py4JJavaError, seq_rdd.foreach, stopit) + self.assertRaises(Py4JJavaError, keyed_rdd.reduceByKeyLocally, stopit) + self.assertRaises(Py4JJavaError, seq_rdd.reduce, stopit) + self.assertRaises(Py4JJavaError, seq_rdd.fold, 0, stopit) + + # the exception raised is non-deterministic + self.assertRaises((Py4JJavaError, RuntimeError), + seq_rdd.aggregate, 0, stopit, lambda *x: 1) + self.assertRaises((Py4JJavaError, RuntimeError), + seq_rdd.aggregate, 0, lambda *x: 1, stopit) + class ProfilerTests(PySparkTestCase): diff --git a/python/pyspark/util.py b/python/pyspark/util.py index 59cc2a6329350..fa1b1c2da0b21 100644 --- a/python/pyspark/util.py +++ b/python/pyspark/util.py @@ -89,6 +89,23 @@ def majorMinorVersion(sparkVersion): " version numbers.") +def fail_on_stopiteration(f): + """ + Wraps the input function to fail on 'StopIteration' by raising a 'RuntimeError' + prevents silent loss of data when 'f' is used in a for loop + """ + def wrapper(*args, **kwargs): + try: + return f(*args, **kwargs) + except StopIteration as exc: + raise RuntimeError( + "Caught StopIteration thrown from user's code; failing the task", + exc + ) + + return wrapper + + if __name__ == "__main__": import doctest (failure_count, test_count) = doctest.testmod() diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index fbcb8af8bfb24..6bf3627c43fb7 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -35,7 +35,7 @@ write_long, read_int, SpecialLengths, UTF8Deserializer, PickleSerializer, \ BatchedSerializer, ArrowStreamPandasSerializer from pyspark.sql.types import to_arrow_type -from pyspark.util import _get_argspec +from pyspark.util import _get_argspec, fail_on_stopiteration from pyspark import shuffle pickleSer = PickleSerializer() @@ -92,10 +92,9 @@ def verify_result_length(*a): return lambda *a: (verify_result_length(*a), arrow_return_type) -def wrap_grouped_map_pandas_udf(f, return_type): +def wrap_grouped_map_pandas_udf(f, return_type, 
argspec): def wrapped(key_series, value_series): import pandas as pd - argspec = _get_argspec(f) if len(argspec.args) == 1: result = f(pd.concat(value_series, axis=1)) @@ -140,15 +139,20 @@ def read_single_udf(pickleSer, infile, eval_type): else: row_func = chain(row_func, f) + # make sure StopIteration's raised in the user code are not + # ignored, but re-raised as RuntimeError's + func = fail_on_stopiteration(row_func) + # the last returnType will be the return type of UDF if eval_type == PythonEvalType.SQL_SCALAR_PANDAS_UDF: - return arg_offsets, wrap_scalar_pandas_udf(row_func, return_type) + return arg_offsets, wrap_scalar_pandas_udf(func, return_type) elif eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF: - return arg_offsets, wrap_grouped_map_pandas_udf(row_func, return_type) + argspec = _get_argspec(row_func) # signature was lost when wrapping it + return arg_offsets, wrap_grouped_map_pandas_udf(func, return_type, argspec) elif eval_type == PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF: - return arg_offsets, wrap_grouped_agg_pandas_udf(row_func, return_type) + return arg_offsets, wrap_grouped_agg_pandas_udf(func, return_type) elif eval_type == PythonEvalType.SQL_BATCHED_UDF: - return arg_offsets, wrap_udf(row_func, return_type) + return arg_offsets, wrap_udf(func, return_type) else: raise ValueError("Unknown eval type: {}".format(eval_type)) From 1981d8e0328dd3cba06e1765ef1397075f036e1e Mon Sep 17 00:00:00 2001 From: edorigatti Date: Fri, 1 Jun 2018 16:02:52 +0200 Subject: [PATCH 03/10] fixed tests, added test for rdd.foreach --- python/pyspark/rdd.py | 5 ++++- python/pyspark/sql/tests.py | 2 -- python/pyspark/tests.py | 4 +++- 3 files changed, 7 insertions(+), 4 deletions(-) diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 2834397b6a213..14d9128502ab0 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -799,6 +799,8 @@ def foreach(self, f): >>> def f(x): print(x) >>> sc.parallelize([1, 2, 3, 4, 5]).foreach(f) """ + f = fail_on_stopiteration(f) + def processPartition(iterator): for x in iterator: f(x) @@ -1636,7 +1638,6 @@ def reduceByKey(self, func, numPartitions=None, partitionFunc=portable_hash): >>> sorted(rdd.reduceByKey(add).collect()) [('a', 2), ('b', 1)] """ - func = fail_on_stopiteration(func) return self.combineByKey(lambda x: x, func, func, numPartitions, partitionFunc) def reduceByKeyLocally(self, func): @@ -1652,6 +1653,8 @@ def reduceByKeyLocally(self, func): >>> sorted(rdd.reduceByKeyLocally(add).items()) [('a', 2), ('b', 1)] """ + func = fail_on_stopiteration(func) + def reducePartition(iterator): m = {} for k, v in iterator: diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 5848d0531f4a5..6dbe6aeff0556 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -901,8 +901,6 @@ def __call__(self, x): self.assertEqual(return_type, f_.returnType) def test_stopiteration_in_udf(self): - return - from pyspark.sql.functions import udf, pandas_udf, PandasUDFType from py4j.protocol import Py4JJavaError diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index 25a912cf09cca..96bb1ee4066ce 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -1303,11 +1303,13 @@ def stopit(*x): self.assertRaises(Py4JJavaError, seq_rdd.filter(stopit).collect) self.assertRaises(Py4JJavaError, seq_rdd.cartesian(seq_rdd).flatMap(stopit).collect) self.assertRaises(Py4JJavaError, seq_rdd.foreach, stopit) - self.assertRaises(Py4JJavaError, keyed_rdd.reduceByKeyLocally, stopit) 
self.assertRaises(Py4JJavaError, seq_rdd.reduce, stopit) self.assertRaises(Py4JJavaError, seq_rdd.fold, 0, stopit) + self.assertRaises(Py4JJavaError, seq_rdd.foreach, stopit) # the exception raised is non-deterministic + self.assertRaises((Py4JJavaError, RuntimeError), + keyed_rdd.reduceByKeyLocally, stopit) self.assertRaises((Py4JJavaError, RuntimeError), seq_rdd.aggregate, 0, stopit, lambda *x: 1) self.assertRaises((Py4JJavaError, RuntimeError), From 39ab167c3d28bf609d752684f09d9d9851bc87fb Mon Sep 17 00:00:00 2001 From: edorigatti Date: Fri, 1 Jun 2018 17:35:19 +0200 Subject: [PATCH 04/10] checking exception message --- python/pyspark/tests.py | 28 +++++++++++++++------------- 1 file changed, 15 insertions(+), 13 deletions(-) diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index 96bb1ee4066ce..e592c1984c578 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -1298,22 +1298,24 @@ def stopit(*x): seq_rdd = self.sc.parallelize(range(10)) keyed_rdd = self.sc.parallelize((x % 2, x) for x in range(10)) + msg = "Caught StopIteration thrown from user's code; failing the task" - self.assertRaises(Py4JJavaError, seq_rdd.map(stopit).collect) - self.assertRaises(Py4JJavaError, seq_rdd.filter(stopit).collect) - self.assertRaises(Py4JJavaError, seq_rdd.cartesian(seq_rdd).flatMap(stopit).collect) - self.assertRaises(Py4JJavaError, seq_rdd.foreach, stopit) - self.assertRaises(Py4JJavaError, seq_rdd.reduce, stopit) - self.assertRaises(Py4JJavaError, seq_rdd.fold, 0, stopit) - self.assertRaises(Py4JJavaError, seq_rdd.foreach, stopit) + self.assertRaisesRegexp(Py4JJavaError, msg, seq_rdd.map(stopit).collect) + self.assertRaisesRegexp(Py4JJavaError, msg, seq_rdd.filter(stopit).collect) + self.assertRaisesRegexp(Py4JJavaError, msg, seq_rdd.foreach, stopit) + self.assertRaisesRegexp(Py4JJavaError, msg, seq_rdd.reduce, stopit) + self.assertRaisesRegexp(Py4JJavaError, msg, seq_rdd.fold, 0, stopit) + self.assertRaisesRegexp(Py4JJavaError, msg, seq_rdd.foreach, stopit) + self.assertRaisesRegexp(Py4JJavaError, msg, + seq_rdd.cartesian(seq_rdd).flatMap(stopit).collect) # the exception raised is non-deterministic - self.assertRaises((Py4JJavaError, RuntimeError), - keyed_rdd.reduceByKeyLocally, stopit) - self.assertRaises((Py4JJavaError, RuntimeError), - seq_rdd.aggregate, 0, stopit, lambda *x: 1) - self.assertRaises((Py4JJavaError, RuntimeError), - seq_rdd.aggregate, 0, lambda *x: 1, stopit) + self.assertRaisesRegexp((Py4JJavaError, RuntimeError), msg, + keyed_rdd.reduceByKeyLocally, stopit) + self.assertRaisesRegexp((Py4JJavaError, RuntimeError), msg, + seq_rdd.aggregate, 0, stopit, lambda *x: 1) + self.assertRaisesRegexp((Py4JJavaError, RuntimeError), msg, + seq_rdd.aggregate, 0, lambda *x: 1, stopit) class ProfilerTests(PySparkTestCase): From 8505de28231e99b63371d0798545b693692cbce4 Mon Sep 17 00:00:00 2001 From: edorigatti Date: Fri, 1 Jun 2018 19:34:06 +0200 Subject: [PATCH 05/10] fixed udf test --- python/pyspark/sql/tests.py | 76 ++++++++++++++++++------------------- 1 file changed, 37 insertions(+), 39 deletions(-) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 6dbe6aeff0556..e04433eca8248 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -900,45 +900,6 @@ def __call__(self, x): self.assertEqual(f, f_.func) self.assertEqual(return_type, f_.returnType) - def test_stopiteration_in_udf(self): - from pyspark.sql.functions import udf, pandas_udf, PandasUDFType - from py4j.protocol import Py4JJavaError - - def 
do_test(action, *args, **kwargs): - exc_message = "Caught StopIteration thrown from user's code; failing the task" - with self.assertRaisesRegexp(Py4JJavaError, exc_message) as cm: - action(*args, **kwargs) - - def foo(x): - raise StopIteration() - - def foofoo(x, y): - raise StopIteration() - - df = self.spark.range(0, 100) - - # plain udf (test for SPARK-23754) - do_test(df.withColumn('v', udf(foo)('id')).show) - - # pandas scalar udf - do_test(df.withColumn( - 'v', pandas_udf(foo, 'double', PandasUDFType.SCALAR)('id') - ).show) - - # pandas grouped map - do_test(df.groupBy('id').apply( - pandas_udf(foo, df.schema, PandasUDFType.GROUPED_MAP) - ).show) - - do_test(df.groupBy('id').apply( - pandas_udf(foofoo, df.schema, PandasUDFType.GROUPED_MAP) - ).show) - - # pandas grouped agg - do_test(df.groupBy('id').agg( - pandas_udf(foo, 'double', PandasUDFType.GROUPED_AGG)('id') - ).show) - def test_validate_column_types(self): from pyspark.sql.functions import udf, to_json from pyspark.sql.column import _to_java_column @@ -4119,6 +4080,43 @@ def foo(df): def foo(k, v, w): return k + def test_stopiteration_in_udf(self): + from pyspark.sql.functions import udf, pandas_udf, PandasUDFType + from py4j.protocol import Py4JJavaError + + def foo(x): + raise StopIteration() + + def foofoo(x, y): + raise StopIteration() + + exc_message = "Caught StopIteration thrown from user's code; failing the task" + df = self.spark.range(0, 100) + + # plain udf (test for SPARK-23754) + self.assertRaisesRegexp(Py4JJavaError, exc_message, df.withColumn( + 'v', udf(foo)('id') + ).collect()) + + # pandas scalar udf + self.assertRaisesRegexp(Py4JJavaError, exc_message, df.withColumn( + 'v', pandas_udf(foo, 'double', PandasUDFType.SCALAR)('id') + ).collect()) + + # pandas grouped map + self.assertRaisesRegexp(Py4JJavaError, exc_message, df.groupBy('id').apply( + pandas_udf(foo, df.schema, PandasUDFType.GROUPED_MAP) + ).collect()) + + self.assertRaisesRegexp(Py4JJavaError, exc_message, df.groupBy('id').apply( + pandas_udf(foofoo, df.schema, PandasUDFType.GROUPED_MAP) + ).collect()) + + # pandas grouped agg + self.assertRaisesRegexp(Py4JJavaError, exc_message, df.groupBy('id').agg( + pandas_udf(foo, 'double', PandasUDFType.GROUPED_AGG)('id') + ).collect()) + @unittest.skipIf( not _have_pandas or not _have_pyarrow, From 20d26a61b4dcae1b46d10f0d111af3e4be3e578b Mon Sep 17 00:00:00 2001 From: edorigatti Date: Mon, 4 Jun 2018 10:23:15 +0200 Subject: [PATCH 06/10] catching both exceptions --- python/pyspark/sql/tests.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index e04433eca8248..c753aafdc0b36 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -4091,29 +4091,30 @@ def foofoo(x, y): raise StopIteration() exc_message = "Caught StopIteration thrown from user's code; failing the task" + excs = (Py4JJavaError, RuntimeError) df = self.spark.range(0, 100) # plain udf (test for SPARK-23754) - self.assertRaisesRegexp(Py4JJavaError, exc_message, df.withColumn( + self.assertRaisesRegexp(excs, exc_message, df.withColumn( 'v', udf(foo)('id') ).collect()) # pandas scalar udf - self.assertRaisesRegexp(Py4JJavaError, exc_message, df.withColumn( + self.assertRaisesRegexp(excs, exc_message, df.withColumn( 'v', pandas_udf(foo, 'double', PandasUDFType.SCALAR)('id') ).collect()) # pandas grouped map - self.assertRaisesRegexp(Py4JJavaError, exc_message, df.groupBy('id').apply( + self.assertRaisesRegexp(excs, exc_message, 
df.groupBy('id').apply( pandas_udf(foo, df.schema, PandasUDFType.GROUPED_MAP) ).collect()) - self.assertRaisesRegexp(Py4JJavaError, exc_message, df.groupBy('id').apply( + self.assertRaisesRegexp(excs, exc_message, df.groupBy('id').apply( pandas_udf(foofoo, df.schema, PandasUDFType.GROUPED_MAP) ).collect()) # pandas grouped agg - self.assertRaisesRegexp(Py4JJavaError, exc_message, df.groupBy('id').agg( + self.assertRaisesRegexp(excs, exc_message, df.groupBy('id').agg( pandas_udf(foo, 'double', PandasUDFType.GROUPED_AGG)('id') ).collect()) From 4cc2b5e5ba38e0312affe393468db515ebb61994 Mon Sep 17 00:00:00 2001 From: edorigatti Date: Mon, 4 Jun 2018 11:50:05 +0200 Subject: [PATCH 07/10] fix --- python/pyspark/sql/tests.py | 21 ++++++++++----------- 1 file changed, 10 insertions(+), 11 deletions(-) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index c753aafdc0b36..801c4b0d9274c 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -4091,32 +4091,31 @@ def foofoo(x, y): raise StopIteration() exc_message = "Caught StopIteration thrown from user's code; failing the task" - excs = (Py4JJavaError, RuntimeError) df = self.spark.range(0, 100) # plain udf (test for SPARK-23754) - self.assertRaisesRegexp(excs, exc_message, df.withColumn( + self.assertRaisesRegexp(Py4JJavaError, exc_message, df.withColumn( 'v', udf(foo)('id') - ).collect()) + ).collect) # pandas scalar udf - self.assertRaisesRegexp(excs, exc_message, df.withColumn( + self.assertRaisesRegexp(Py4JJavaError, exc_message, df.withColumn( 'v', pandas_udf(foo, 'double', PandasUDFType.SCALAR)('id') - ).collect()) + ).collect) # pandas grouped map - self.assertRaisesRegexp(excs, exc_message, df.groupBy('id').apply( + self.assertRaisesRegexp(Py4JJavaError, exc_message, df.groupBy('id').apply( pandas_udf(foo, df.schema, PandasUDFType.GROUPED_MAP) - ).collect()) + ).collect) - self.assertRaisesRegexp(excs, exc_message, df.groupBy('id').apply( + self.assertRaisesRegexp(Py4JJavaError, exc_message, df.groupBy('id').apply( pandas_udf(foofoo, df.schema, PandasUDFType.GROUPED_MAP) - ).collect()) + ).collect) # pandas grouped agg - self.assertRaisesRegexp(excs, exc_message, df.groupBy('id').agg( + self.assertRaisesRegexp(Py4JJavaError, exc_message, df.groupBy('id').agg( pandas_udf(foo, 'double', PandasUDFType.GROUPED_AGG)('id') - ).collect()) + ).collect) @unittest.skipIf( From c60225a1169e9d6395bb2c9035c908b6542a065e Mon Sep 17 00:00:00 2001 From: edorigatti Date: Tue, 5 Jun 2018 08:01:39 +0200 Subject: [PATCH 08/10] moved comment --- python/pyspark/util.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/pyspark/util.py b/python/pyspark/util.py index fa1b1c2da0b21..784999bd734cd 100644 --- a/python/pyspark/util.py +++ b/python/pyspark/util.py @@ -53,11 +53,11 @@ def _get_argspec(f): """ Get argspec of a function. Supports both Python 2 and Python 3. """ - # `getargspec` is deprecated since python3.0 (incompatible with function annotations). - # See SPARK-23569. if sys.version_info[0] < 3: argspec = inspect.getargspec(f) else: + # `getargspec` is deprecated since python3.0 (incompatible with function annotations). + # See SPARK-23569. 
argspec = inspect.getfullargspec(f) return argspec From 7cb95568565bcfc44f5409707b08878774d02cc1 Mon Sep 17 00:00:00 2001 From: edorigatti Date: Thu, 7 Jun 2018 14:17:08 +0200 Subject: [PATCH 09/10] formatting --- python/pyspark/sql/tests.py | 48 +++++++++++++++++++++++++------------ 1 file changed, 33 insertions(+), 15 deletions(-) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 801c4b0d9274c..31cd138040855 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -4094,28 +4094,46 @@ def foofoo(x, y): df = self.spark.range(0, 100) # plain udf (test for SPARK-23754) - self.assertRaisesRegexp(Py4JJavaError, exc_message, df.withColumn( - 'v', udf(foo)('id') - ).collect) + self.assertRaisesRegexp( + Py4JJavaError, + exc_message, + df.withColumn('v', udf(foo)('id')).collect + ) # pandas scalar udf - self.assertRaisesRegexp(Py4JJavaError, exc_message, df.withColumn( - 'v', pandas_udf(foo, 'double', PandasUDFType.SCALAR)('id') - ).collect) + self.assertRaisesRegexp( + Py4JJavaError, + exc_message, + df.withColumn( + 'v', pandas_udf(foo, 'double', PandasUDFType.SCALAR)('id') + ).collect + ) # pandas grouped map - self.assertRaisesRegexp(Py4JJavaError, exc_message, df.groupBy('id').apply( - pandas_udf(foo, df.schema, PandasUDFType.GROUPED_MAP) - ).collect) + self.assertRaisesRegexp( + Py4JJavaError, + exc_message, + df.groupBy('id').apply( + pandas_udf(foo, df.schema, PandasUDFType.GROUPED_MAP) + ).collect + ) - self.assertRaisesRegexp(Py4JJavaError, exc_message, df.groupBy('id').apply( - pandas_udf(foofoo, df.schema, PandasUDFType.GROUPED_MAP) - ).collect) + self.assertRaisesRegexp( + Py4JJavaError, + exc_message, + df.groupBy('id').apply( + pandas_udf(foofoo, df.schema, PandasUDFType.GROUPED_MAP) + ).collect + ) # pandas grouped agg - self.assertRaisesRegexp(Py4JJavaError, exc_message, df.groupBy('id').agg( - pandas_udf(foo, 'double', PandasUDFType.GROUPED_AGG)('id') - ).collect) + self.assertRaisesRegexp( + Py4JJavaError, + exc_message, + df.groupBy('id').agg( + pandas_udf(foo, 'double', PandasUDFType.GROUPED_AGG)('id') + ).collect + ) @unittest.skipIf( From 9724640c534f3f1600ae3c37988479e7d0500cd0 Mon Sep 17 00:00:00 2001 From: edorigatti Date: Thu, 7 Jun 2018 14:17:29 +0200 Subject: [PATCH 10/10] explaining in comments --- python/pyspark/tests.py | 5 ++++- python/pyspark/util.py | 2 +- python/pyspark/worker.py | 4 ++-- 3 files changed, 7 insertions(+), 4 deletions(-) diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index e592c1984c578..18b2f251dc9fd 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -1309,7 +1309,10 @@ def stopit(*x): self.assertRaisesRegexp(Py4JJavaError, msg, seq_rdd.cartesian(seq_rdd).flatMap(stopit).collect) - # the exception raised is non-deterministic + # these methods call the user function both in the driver and in the executor + # the exception raised is different according to where the StopIteration happens + # RuntimeError is raised if in the driver + # Py4JJavaError is raised if in the executor (wraps the RuntimeError raised in the worker) self.assertRaisesRegexp((Py4JJavaError, RuntimeError), msg, keyed_rdd.reduceByKeyLocally, stopit) self.assertRaisesRegexp((Py4JJavaError, RuntimeError), msg, diff --git a/python/pyspark/util.py b/python/pyspark/util.py index 784999bd734cd..f015542c8799d 100644 --- a/python/pyspark/util.py +++ b/python/pyspark/util.py @@ -92,7 +92,7 @@ def majorMinorVersion(sparkVersion): def fail_on_stopiteration(f): """ Wraps the input function to fail on 
'StopIteration' by raising a 'RuntimeError' - prevents silent loss of data when 'f' is used in a for loop + prevents silent loss of data when 'f' is used in a for loop in Spark code """ def wrapper(*args, **kwargs): try: diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index 6bf3627c43fb7..a30d6bf523a50 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -139,8 +139,8 @@ def read_single_udf(pickleSer, infile, eval_type): else: row_func = chain(row_func, f) - # make sure StopIteration's raised in the user code are not - # ignored, but re-raised as RuntimeError's + # make sure StopIteration's raised in the user code are not ignored + # when they are processed in a for loop, raise them as RuntimeError's instead func = fail_on_stopiteration(row_func) # the last returnType will be the return type of UDF
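
For context, a minimal self-contained sketch (Python 3) of the behavior this series addresses. The fail_on_stopiteration body below mirrors the helper added to python/pyspark/util.py in PATCH 02/10; consume and buggy_udf are hypothetical stand-ins, not part of the patch, for the worker-side map() pipeline and for a user function that leaks StopIteration.

    def fail_on_stopiteration(f):
        """Re-raise StopIteration from user code as RuntimeError (as in PATCH 02/10)."""
        def wrapper(*args, **kwargs):
            try:
                return f(*args, **kwargs)
            except StopIteration as exc:
                raise RuntimeError(
                    "Caught StopIteration thrown from user's code; failing the task",
                    exc
                )
        return wrapper


    def consume(f, data):
        # Stand-in for the PySpark worker loop: rdd.py applies the user function
        # via map(), so a StopIteration raised by f() looks like normal
        # end-of-iteration to the consumer and simply stops the stream early.
        return list(map(f, data))


    def buggy_udf(x):
        if x == 3:
            raise StopIteration()  # user bug; should have been e.g. ValueError
        return x * 10


    print(consume(buggy_udf, range(6)))  # [0, 10, 20] -- rows 3..5 silently dropped

    try:
        consume(fail_on_stopiteration(buggy_udf), range(6))
    except RuntimeError as err:
        # the failure now surfaces: "Caught StopIteration thrown from user's code; ..."
        print(err)

Because the wrapper hides the original function's signature, PATCH 02/10 also reads the argspec of the unwrapped function in worker.py (_get_argspec(row_func), noted there as "signature was lost when wrapping it") and passes it to wrap_grouped_map_pandas_udf explicitly, rather than inspecting the wrapped callable.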