Commit addd35f (parent: 4d40893)

Added a check for when a pandas_udf return value is a timestamp with tz, and added comments on the conversion functions' input and output.

File tree: 3 files changed (+36, -9 lines)

python/pyspark/sql/dataframe.py

Lines changed: 2 additions & 2 deletions
@@ -1885,8 +1885,8 @@ def toPandas(self):
                 tables = self._collectAsArrow()
                 if tables:
                     table = pyarrow.concat_tables(tables)
-                    df = table.to_pandas()
-                    return _check_dataframe_localize_timestamps(df)
+                    pdf = table.to_pandas()
+                    return _check_dataframe_localize_timestamps(pdf)
                 else:
                     return pd.DataFrame.from_records([], columns=self.columns)
             except ImportError as e:
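
For context, the localization step that _check_dataframe_localize_timestamps applies to the Arrow-collected result can be reproduced standalone in pandas. A minimal sketch, assuming pandas (with dateutil) is installed; the column name "ts" and the sample dates are illustrative, and iteritems() matches the pandas API of this era (later replaced by items()):

import pandas as pd
from pandas.api.types import is_datetime64tz_dtype

# Illustrative tz-aware column; an Arrow-collected result looks similar.
pdf = pd.DataFrame({"ts": pd.date_range("2017-01-01", periods=3, tz="UTC")})
for column, series in pdf.iteritems():  # items() in pandas >= 2.0
    if is_datetime64tz_dtype(series.dtype):
        # Same chain as the diff: shift to local time, then drop the tz info.
        pdf[column] = series.dt.tz_convert("tzlocal()").dt.tz_localize(None)
print(pdf.dtypes)  # "ts" is now tz-naive datetime64[ns]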

python/pyspark/sql/tests.py

Lines changed: 18 additions & 0 deletions
@@ -3458,6 +3458,24 @@ def check_data(idx, timestamp, timestamp_copy):
             self.assertEquals(data[i][1], result[i][1])  # "timestamp" col
             self.assertTrue(result[i][3])  # "is_equal" data in udf was as expected
 
+    def test_vectorized_udf_return_timestamp_tz(self):
+        from pyspark.sql.functions import pandas_udf, col
+        import pandas as pd
+        df = self.spark.range(10)
+
+        @pandas_udf(returnType=TimestampType())
+        def gen_timestamps(id):
+            ts = [pd.Timestamp(i, unit='D', tz='America/Los_Angeles') for i in id]
+            return pd.Series(ts)
+
+        result = df.withColumn("ts", gen_timestamps(col("id"))).collect()
+        spark_ts_t = TimestampType()
+        for r in result:
+            i, ts = r
+            ts_tz = pd.Timestamp(i, unit='D', tz='America/Los_Angeles').to_pydatetime()
+            expected = spark_ts_t.fromInternal(spark_ts_t.toInternal(ts_tz))
+            self.assertEquals(expected, ts)
+
 
 @unittest.skipIf(not _have_pandas or not _have_arrow, "Pandas or Arrow not installed")
 class GroupbyApplyTests(ReusedPySparkTestCase):
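
The expected values in this test round-trip through Spark's internal timestamp representation. A minimal sketch of that round trip, assuming pyspark and pandas are importable (the day offset 5 is arbitrary):

import pandas as pd
from pyspark.sql.types import TimestampType

spark_ts_t = TimestampType()
ts_tz = pd.Timestamp(5, unit='D', tz='America/Los_Angeles').to_pydatetime()
# toInternal() stores the timestamp as microseconds since the epoch;
# fromInternal() reads it back as a tz-naive datetime in local time,
# which is what a collected Row contains.
internal = spark_ts_t.toInternal(ts_tz)
expected = spark_ts_t.fromInternal(internal)
print(internal, expected)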

python/pyspark/sql/types.py

Lines changed: 16 additions & 7 deletions
@@ -1629,24 +1629,33 @@ def to_arrow_type(dt):
     return arrow_type
 
 
-def _check_dataframe_localize_timestamps(df):
-    """ Convert timezone aware timestamps to timezone-naive in local time
+def _check_dataframe_localize_timestamps(pdf):
+    """
+    Convert timezone aware timestamps to timezone-naive in local time
+
+    :param pdf: pandas.DataFrame
+    :return pandas.DataFrame where any timezone aware columns have been converted to tz-naive
     """
     from pandas.api.types import is_datetime64tz_dtype
-    for column, series in df.iteritems():
+    for column, series in pdf.iteritems():
         # TODO: handle nested timestamps, such as ArrayType(TimestampType())?
         if is_datetime64tz_dtype(series.dtype):
-            df[column] = series.dt.tz_convert('tzlocal()').dt.tz_localize(None)
-    return df
+            pdf[column] = series.dt.tz_convert('tzlocal()').dt.tz_localize(None)
+    return pdf
 
 
 def _check_series_convert_timestamps_internal(s):
-    """ Convert a tz-naive timestamp in local tz to UTC normalized for Spark internal storage
     """
-    from pandas.api.types import is_datetime64_dtype
+    Convert a tz-naive timestamp in local tz to UTC normalized for Spark internal storage
+    :param s: a pandas.Series
+    :return pandas.Series where if it is a timestamp, it has been UTC normalized without a time zone
+    """
+    from pandas.api.types import is_datetime64_dtype, is_datetime64tz_dtype
     # TODO: handle nested timestamps, such as ArrayType(TimestampType())?
     if is_datetime64_dtype(s.dtype):
         return s.dt.tz_localize('tzlocal()').dt.tz_convert('UTC')
+    elif is_datetime64tz_dtype(s.dtype):
+        return s.dt.tz_convert('UTC')
     else:
         return s
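
The new elif branch handles series that are already timezone-aware, such as the output of the pandas_udf in the test above: they are converted straight to UTC rather than localized first. A minimal sketch exercising both branches, mirroring the function body from the diff (pandas only, no Spark required):

import pandas as pd
from pandas.api.types import is_datetime64_dtype, is_datetime64tz_dtype

def _check_series_convert_timestamps_internal(s):
    # tz-naive: assume local time, then normalize to UTC
    if is_datetime64_dtype(s.dtype):
        return s.dt.tz_localize('tzlocal()').dt.tz_convert('UTC')
    # tz-aware (the new branch): convert directly to UTC
    elif is_datetime64tz_dtype(s.dtype):
        return s.dt.tz_convert('UTC')
    else:
        return s

naive = pd.Series(pd.date_range('1970-01-01', periods=2))
aware = pd.Series(pd.date_range('1970-01-01', periods=2, tz='America/Los_Angeles'))
print(_check_series_convert_timestamps_internal(naive).dt.tz)  # UTC
print(_check_series_convert_timestamps_internal(aware).dt.tz)  # UTC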