Skip to content

Commit 44d8497

Browse files
author
Davies Liu
committed
add timezone support for DateType
1 parent 99d9d9c commit 44d8497

File tree

2 files changed

+20
-10
lines changed

2 files changed

+20
-10
lines changed

python/pyspark/sql/tests.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -604,14 +604,17 @@ def test_filter_with_datetime(self):
604604
self.assertEqual(0, df.filter(df.time > time).count())
605605

606606
def test_time_with_timezone(self):
607+
day = datetime.date.today()
607608
now = datetime.datetime.now()
608609
ts = time.mktime(now.timetuple()) + now.microsecond / 1e6
609610
# class in __main__ is not serializable
610611
from pyspark.sql.tests import UTC
611612
utc = UTC()
612613
utcnow = datetime.datetime.fromtimestamp(ts, utc)
613-
df = self.sqlCtx.createDataFrame([(now, utcnow)])
614-
now1, utcnow1 = df.first()
614+
df = self.sqlCtx.createDataFrame([(day, now, utcnow)])
615+
day1, now1, utcnow1 = df.first()
616+
# Pyrolite serialize java.sql.Date as datetime, will be fixed in new version
617+
self.assertEqual(day1.date(), day)
615618
# Pyrolite does not support microsecond, the error should be
616619
# less than 1 millisecond
617620
self.assertTrue(now - now1 < datetime.timedelta(0.001))

python/pyspark/sql/types.py

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -655,12 +655,15 @@ def _need_python_to_sql_conversion(dataType):
655655
_need_python_to_sql_conversion(dataType.valueType)
656656
elif isinstance(dataType, UserDefinedType):
657657
return True
658-
elif isinstance(dataType, TimestampType):
658+
elif isinstance(dataType, (DateType, TimestampType)):
659659
return True
660660
else:
661661
return False
662662

663663

664+
EPOCH_ORDINAL = datetime.datetime(1970, 1, 1).toordinal()
665+
666+
664667
def _python_to_sql_converter(dataType):
665668
"""
666669
Returns a converter that converts a Python object into a SQL datum for the given type.
@@ -698,26 +701,30 @@ def converter(obj):
698701
return tuple(c(d.get(n)) for n, c in zip(names, converters))
699702
else:
700703
return tuple(c(v) for c, v in zip(converters, obj))
701-
else:
704+
elif obj is not None:
702705
raise ValueError("Unexpected tuple %r with type %r" % (obj, dataType))
703706
return converter
704707
elif isinstance(dataType, ArrayType):
705708
element_converter = _python_to_sql_converter(dataType.elementType)
706-
return lambda a: [element_converter(v) for v in a]
709+
return lambda a: a and [element_converter(v) for v in a]
707710
elif isinstance(dataType, MapType):
708711
key_converter = _python_to_sql_converter(dataType.keyType)
709712
value_converter = _python_to_sql_converter(dataType.valueType)
710-
return lambda m: dict([(key_converter(k), value_converter(v)) for k, v in m.items()])
713+
return lambda m: m and dict([(key_converter(k), value_converter(v)) for k, v in m.items()])
711714

712715
elif isinstance(dataType, UserDefinedType):
713-
return lambda obj: dataType.serialize(obj)
716+
return lambda obj: obj and dataType.serialize(obj)
717+
718+
elif isinstance(dataType, DateType):
719+
return lambda d: d and d.toordinal() - EPOCH_ORDINAL
714720

715721
elif isinstance(dataType, TimestampType):
716722

717723
def to_posix_timstamp(dt):
718-
seconds = (calendar.timegm(dt.utctimetuple()) if dt.tzinfo
719-
else time.mktime(dt.timetuple()))
720-
return int(seconds * 1e7 + dt.microsecond * 10)
724+
if dt:
725+
seconds = (calendar.timegm(dt.utctimetuple()) if dt.tzinfo
726+
else time.mktime(dt.timetuple()))
727+
return int(seconds * 1e7 + dt.microsecond * 10)
721728
return to_posix_timstamp
722729

723730
else:

0 commit comments

Comments
 (0)