diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py
index 4c16b5fc26f3..82abf1947c81 100644
--- a/python/pyspark/serializers.py
+++ b/python/pyspark/serializers.py
@@ -216,9 +216,10 @@ def _create_batch(series, timezone):
     :param timezone: A timezone to respect when handling timestamp values
     :return: Arrow RecordBatch
     """
-
-    from pyspark.sql.types import _check_series_convert_timestamps_internal
+    import decimal
+    from distutils.version import LooseVersion
     import pyarrow as pa
+    from pyspark.sql.types import _check_series_convert_timestamps_internal
     # Make input conform to [(series1, type1), (series2, type2), ...]
     if not isinstance(series, (list, tuple)) or \
             (len(series) == 2 and isinstance(series[1], pa.DataType)):
@@ -236,6 +237,11 @@ def create_array(s, t):
             # TODO: need decode before converting to Arrow in Python 2
             return pa.Array.from_pandas(s.apply(
                 lambda v: v.decode("utf-8") if isinstance(v, str) else v), mask=mask, type=t)
+        elif t is not None and pa.types.is_decimal(t) and \
+                LooseVersion("0.9.0") <= LooseVersion(pa.__version__) < LooseVersion("0.10.0"):
+            # TODO: see ARROW-2432. Remove when the minimum PyArrow version becomes 0.10.0.
+            return pa.Array.from_pandas(s.apply(
+                lambda v: decimal.Decimal('NaN') if v is None else v), mask=mask, type=t)
         return pa.Array.from_pandas(s, mask=mask, type=t)
 
     arrs = [create_array(s, t) for s, t in series]