Skip to content

Commit 6a29aa4

Browse files
author
Davies Liu
committed
support datetime with timezone
1 parent eb4632f commit 6a29aa4

File tree

3 files changed

+45
-2
lines changed

3 files changed

+45
-2
lines changed

python/pyspark/sql/_types.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import decimal
2020
import time
2121
import datetime
22+
import calendar
2223
import keyword
2324
import warnings
2425
import json
@@ -638,6 +639,8 @@ def _need_python_to_sql_conversion(dataType):
638639
elif isinstance(dataType, MapType):
639640
return _need_python_to_sql_conversion(dataType.keyType) or \
640641
_need_python_to_sql_conversion(dataType.valueType)
642+
elif isinstance(dataType, TimestampType):
643+
return True
641644
elif isinstance(dataType, UserDefinedType):
642645
return True
643646
else:
@@ -691,6 +694,16 @@ def converter(obj):
691694
key_converter = _python_to_sql_converter(dataType.keyType)
692695
value_converter = _python_to_sql_converter(dataType.valueType)
693696
return lambda m: dict([(key_converter(k), value_converter(v)) for k, v in m.items()])
697+
698+
elif isinstance(dataType, TimestampType):
699+
700+
def to_posix_timstamp(dt):
701+
if dt.tzinfo is None:
702+
return time.mktime(dt.timetuple()) + dt.microsecond / 1e6
703+
else:
704+
return calendar.timegm(dt.utctimetuple()) + dt.microsecond / 1e6
705+
return to_posix_timstamp
706+
694707
elif isinstance(dataType, UserDefinedType):
695708
return lambda obj: dataType.serialize(obj)
696709
else:

python/pyspark/sql/tests.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
import tempfile
2727
import pickle
2828
import functools
29+
import time
2930
import datetime
3031

3132
import py4j
@@ -46,6 +47,20 @@
4647
from pyspark.sql.functions import UserDefinedFunction
4748

4849

50+
class UTC(datetime.tzinfo):
51+
"""UTC"""
52+
ZERO = datetime.timedelta(0)
53+
54+
def utcoffset(self, dt):
55+
return self.ZERO
56+
57+
def tzname(self, dt):
58+
return "UTC"
59+
60+
def dst(self, dt):
61+
return self.ZERO
62+
63+
4964
class ExamplePointUDT(UserDefinedType):
5065
"""
5166
User-defined type (UDT) for ExamplePoint.
@@ -571,6 +586,20 @@ def test_filter_with_datetime(self):
571586
self.assertEqual(0, df.filter(df.date > date).count())
572587
self.assertEqual(0, df.filter(df.time > time).count())
573588

589+
def test_time_with_timezone(self):
590+
now = datetime.datetime.now()
591+
ts = time.mktime(now.timetuple()) + now.microsecond / 1e6
592+
# class in __main__ is not serializable
593+
from pyspark.sql.tests import UTC
594+
utc = UTC()
595+
utcnow = datetime.datetime.fromtimestamp(ts, utc)
596+
df = self.sqlCtx.createDataFrame([(now, utcnow)])
597+
now1, utcnow1 = df.first()
598+
# Spark SQL does not support microsecond, the error should be
599+
# less than 1 millisecond
600+
self.assertTrue(now1 - now < datetime.timedelta(0.001))
601+
self.assertTrue(utcnow1 - now < datetime.timedelta(0.001))
602+
574603
def test_dropna(self):
575604
schema = StructType([
576605
StructField("name", StringType(), True),

sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,7 @@ import org.apache.spark.annotation.DeveloperApi
2828
import org.apache.spark.api.python.{PythonBroadcast, PythonRDD}
2929
import org.apache.spark.broadcast.Broadcast
3030
import org.apache.spark.rdd.RDD
31-
import org.apache.spark.sql.catalyst.expressions.Row
32-
import org.apache.spark.sql.catalyst.expressions._
31+
import org.apache.spark.sql.catalyst.expressions.{Row, _}
3332
import org.apache.spark.sql.catalyst.plans.logical
3433
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
3534
import org.apache.spark.sql.catalyst.rules.Rule
@@ -184,6 +183,8 @@ object EvaluatePython {
184183

185184
case (c: java.util.Calendar, TimestampType) =>
186185
new java.sql.Timestamp(c.getTime().getTime())
186+
case (c: Double, TimestampType) =>
187+
new java.sql.Timestamp((c * 1000).toLong)
187188

188189
case (_, udt: UserDefinedType[_]) =>
189190
fromJava(obj, udt.sqlType)

0 commit comments

Comments
 (0)