From 5aa8b9e72aee17ffa51f4cb1048f5a3f93a5a380 Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Thu, 13 Jul 2017 09:53:31 -0700 Subject: [PATCH 01/34] added date type and started test, still some issue with time difference --- .../sql/execution/arrow/ArrowConverters.scala | 5 +- .../arrow/ArrowConvertersSuite.scala | 71 +++++++++++++++++-- 2 files changed, 69 insertions(+), 7 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala index 6af5c7342237..7419f9fcda68 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala @@ -21,17 +21,15 @@ import java.io.ByteArrayOutputStream import java.nio.channels.Channels import scala.collection.JavaConverters._ - import io.netty.buffer.ArrowBuf import org.apache.arrow.memory.{BufferAllocator, RootAllocator} import org.apache.arrow.vector._ import org.apache.arrow.vector.BaseValueVector.BaseMutator import org.apache.arrow.vector.file._ import org.apache.arrow.vector.schema.{ArrowFieldNode, ArrowRecordBatch} -import org.apache.arrow.vector.types.FloatingPointPrecision +import org.apache.arrow.vector.types.{DateUnit, FloatingPointPrecision} import org.apache.arrow.vector.types.pojo.{ArrowType, Field, FieldType, Schema} import org.apache.arrow.vector.util.ByteArrayReadableSeekableByteChannel - import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -84,6 +82,7 @@ private[sql] object ArrowConverters { case ByteType => new ArrowType.Int(8, true) case StringType => ArrowType.Utf8.INSTANCE case BinaryType => ArrowType.Binary.INSTANCE + case DateType => new ArrowType.Date(DateUnit.DAY) case _ => throw new UnsupportedOperationException(s"Unsupported data type: $dataType") } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala index 159328cc0d95..ea7698fcf344 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala @@ -792,6 +792,73 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { collectAndValidate(df, json, "binaryData.json") } + test("date type conversion") { + val json = + s""" + |{ + | "schema" : { + | "fields" : [ { + | "name" : "date", + | "type" : { + | "name" : "date", + | "unit" : "DAY" + | }, + | "nullable" : true, + | "children" : [ ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "DATA", + | "typeBitWidth" : 32 + | } ] + | } + | } ] + | }, + | "batches" : [ { + | "count" : 4, + | "columns" : [ { + | "name" : "date", + | "count" : 4, + | "VALIDITY" : [ 1, 1, 1, 1 ], + | "DATA" : [ -1, -1, 16533, 16930 ] + | } ] + | } ] + |} + """.stripMargin + + val sdf = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss.SSS z", Locale.US) + val d1 = new Date(-1/*sdf.parse("1969-12-31 13:10:15.000 UTC").getTime*/) + val d2 = new Date(0/*sdf.parse("1969-12-31 13:10:15.000 UTC").getTime*/) + val d3 = new Date(sdf.parse("2015-04-08 13:10:15.000 UTC").getTime) + val d4 = new Date(sdf.parse("2016-05-09 12:01:01.000 UTC").getTime) + + val df = Seq(d1, d2, d3, d4).toDF("date") + df.show() 
+ println(s"date: $d2, ${d2.getTime}, ${d2.toGMTString}, ${d1.toGMTString}") + + collectAndValidate(df, json, "dateData.json") + } + + ignore("timestamp conversion") { + /*val sdf = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss.SSS z", Locale.US) + val ts1 = new Timestamp(sdf.parse("2013-04-08 01:10:15.567 UTC").getTime) + val ts2 = new Timestamp(sdf.parse("2013-04-08 13:10:10.789 UTC").getTime) + val data = Seq(ts1, ts2) + + val schema = new JSONSchema(Seq(new TimestampType("timestamp"))) + val us_data = data.map(_.getTime * 1000) // convert to microseconds + val columns = Seq( + new PrimitiveColumn("timestamp", data.length, data.map(_ => true), us_data)) + val batch = new JSONRecordBatch(data.length, columns) + val json = new JSONFile(schema, Seq(batch)) + + val df = data.toDF("timestamp") + + collectAndValidate(df, json, "timestampData.json")*/ + } + test("floating-point NaN") { val json = s""" @@ -1044,10 +1111,6 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { runUnsupported { complexData.toArrowPayload.collect() } val sdf = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss.SSS z", Locale.US) - val d1 = new Date(sdf.parse("2015-04-08 13:10:15.000 UTC").getTime) - val d2 = new Date(sdf.parse("2016-05-09 13:10:15.000 UTC").getTime) - runUnsupported { Seq(d1, d2).toDF("date").toArrowPayload.collect() } - val ts1 = new Timestamp(sdf.parse("2013-04-08 01:10:15.567 UTC").getTime) val ts2 = new Timestamp(sdf.parse("2013-04-08 13:10:10.789 UTC").getTime) runUnsupported { Seq(ts1, ts2).toDF("timestamp").toArrowPayload.collect() } From 20313f92758e5639b309ba810945a8415941ef86 Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Mon, 17 Jul 2017 17:42:15 -0700 Subject: [PATCH 02/34] DateTimeUtils forces defaultTimeZone --- .../execution/arrow/ArrowConvertersSuite.scala | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala index ea7698fcf344..79839f267b08 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala @@ -31,6 +31,7 @@ import org.scalatest.BeforeAndAfterAll import org.apache.spark.SparkException import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types.{BinaryType, StructField, StructType} import org.apache.spark.util.Utils @@ -822,27 +823,29 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { | "name" : "date", | "count" : 4, | "VALIDITY" : [ 1, 1, 1, 1 ], - | "DATA" : [ -1, -1, 16533, 16930 ] + | "DATA" : [ -1, 0, 16533, 16930 ] | } ] | } ] |} """.stripMargin val sdf = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss.SSS z", Locale.US) - val d1 = new Date(-1/*sdf.parse("1969-12-31 13:10:15.000 UTC").getTime*/) - val d2 = new Date(0/*sdf.parse("1969-12-31 13:10:15.000 UTC").getTime*/) + val d1 = new Date(-1) // "1969-12-31 13:10:15.000 UTC" + val d2 = new Date(0) // "1970-01-01 13:10:15.000 UTC" val d3 = new Date(sdf.parse("2015-04-08 13:10:15.000 UTC").getTime) val d4 = new Date(sdf.parse("2016-05-09 12:01:01.000 UTC").getTime) + // Date is created unaware of timezone, but DateTimeUtils force defaultTimeZone() + 
assert(DateTimeUtils.toJavaDate(DateTimeUtils.fromJavaDate(d2)).getTime == d2.getTime) + val df = Seq(d1, d2, d3, d4).toDF("date") - df.show() - println(s"date: $d2, ${d2.getTime}, ${d2.toGMTString}, ${d1.toGMTString}") collectAndValidate(df, json, "dateData.json") } ignore("timestamp conversion") { - /*val sdf = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss.SSS z", Locale.US) + /* + val sdf = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss.SSS z", Locale.US) val ts1 = new Timestamp(sdf.parse("2013-04-08 01:10:15.567 UTC").getTime) val ts2 = new Timestamp(sdf.parse("2013-04-08 13:10:10.789 UTC").getTime) val data = Seq(ts1, ts2) @@ -856,7 +859,8 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { val df = data.toDF("timestamp") - collectAndValidate(df, json, "timestampData.json")*/ + collectAndValidate(df, json, "timestampData.json") + */ } test("floating-point NaN") { From 69e1e21bf4bebc7bea6bd9322e4300df71a90b18 Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Mon, 17 Jul 2017 17:48:47 -0700 Subject: [PATCH 03/34] fix style checks --- .../org/apache/spark/sql/execution/arrow/ArrowConverters.scala | 2 ++ 1 file changed, 2 insertions(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala index 7419f9fcda68..2b0103f33007 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala @@ -21,6 +21,7 @@ import java.io.ByteArrayOutputStream import java.nio.channels.Channels import scala.collection.JavaConverters._ + import io.netty.buffer.ArrowBuf import org.apache.arrow.memory.{BufferAllocator, RootAllocator} import org.apache.arrow.vector._ @@ -30,6 +31,7 @@ import org.apache.arrow.vector.schema.{ArrowFieldNode, ArrowRecordBatch} import org.apache.arrow.vector.types.{DateUnit, FloatingPointPrecision} import org.apache.arrow.vector.types.pojo.{ArrowType, Field, FieldType, Schema} import org.apache.arrow.vector.util.ByteArrayReadableSeekableByteChannel + import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.types._ import org.apache.spark.util.Utils From dbfbef3b6318d8715cd72cad5913f05cb4c43aaf Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Tue, 18 Jul 2017 15:02:13 -0700 Subject: [PATCH 04/34] date type java tests passing --- .../sql/execution/arrow/ArrowConvertersSuite.scala | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala index 79839f267b08..c181102515ea 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala @@ -823,20 +823,17 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { | "name" : "date", | "count" : 4, | "VALIDITY" : [ 1, 1, 1, 1 ], - | "DATA" : [ -1, 0, 16533, 16930 ] + | "DATA" : [ -1, 0, 16533, 382607 ] | } ] | } ] |} """.stripMargin val sdf = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss.SSS z", Locale.US) - val d1 = new Date(-1) // "1969-12-31 13:10:15.000 UTC" - val d2 = new Date(0) // "1970-01-01 13:10:15.000 UTC" + val d1 = DateTimeUtils.toJavaDate(-1) // "1969-12-31" + val d2 = DateTimeUtils.toJavaDate(0) // 
"1970-01-01" val d3 = new Date(sdf.parse("2015-04-08 13:10:15.000 UTC").getTime) - val d4 = new Date(sdf.parse("2016-05-09 12:01:01.000 UTC").getTime) - - // Date is created unaware of timezone, but DateTimeUtils force defaultTimeZone() - assert(DateTimeUtils.toJavaDate(DateTimeUtils.fromJavaDate(d2)).getTime == d2.getTime) + val d4 = new Date(sdf.parse("3017-07-18 14:55:00.000 UTC").getTime) val df = Seq(d1, d2, d3, d4).toDF("date") From 436afff95620ac91c731f172b493396b0ba2a8d4 Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Tue, 18 Jul 2017 15:48:12 -0700 Subject: [PATCH 05/34] timestamp type java tests passing --- .../sql/execution/arrow/ArrowConverters.scala | 3 +- .../arrow/ArrowConvertersSuite.scala | 54 ++++++++++++++----- 2 files changed, 43 insertions(+), 14 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala index 2b0103f33007..24448ae651e6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala @@ -28,7 +28,7 @@ import org.apache.arrow.vector._ import org.apache.arrow.vector.BaseValueVector.BaseMutator import org.apache.arrow.vector.file._ import org.apache.arrow.vector.schema.{ArrowFieldNode, ArrowRecordBatch} -import org.apache.arrow.vector.types.{DateUnit, FloatingPointPrecision} +import org.apache.arrow.vector.types.{DateUnit, FloatingPointPrecision, TimeUnit} import org.apache.arrow.vector.types.pojo.{ArrowType, Field, FieldType, Schema} import org.apache.arrow.vector.util.ByteArrayReadableSeekableByteChannel @@ -85,6 +85,7 @@ private[sql] object ArrowConverters { case StringType => ArrowType.Utf8.INSTANCE case BinaryType => ArrowType.Binary.INSTANCE case DateType => new ArrowType.Date(DateUnit.DAY) + case TimestampType => new ArrowType.Timestamp(TimeUnit.MICROSECOND, null) case _ => throw new UnsupportedOperationException(s"Unsupported data type: $dataType") } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala index c181102515ea..9be63da7c134 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala @@ -840,24 +840,52 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { collectAndValidate(df, json, "dateData.json") } - ignore("timestamp conversion") { - /* - val sdf = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss.SSS z", Locale.US) - val ts1 = new Timestamp(sdf.parse("2013-04-08 01:10:15.567 UTC").getTime) - val ts2 = new Timestamp(sdf.parse("2013-04-08 13:10:10.789 UTC").getTime) - val data = Seq(ts1, ts2) + test("timestamp type conversion") { + val json = + s""" + |{ + | "schema" : { + | "fields" : [ { + | "name" : "timestamp", + | "type" : { + | "name" : "timestamp", + | "unit" : "MICROSECOND" + | }, + | "nullable" : true, + | "children" : [ ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "DATA", + | "typeBitWidth" : 64 + | } ] + | } + | } ] + | }, + | "batches" : [ { + | "count" : 4, + | "columns" : [ { + | "name" : "timestamp", + | "count" : 4, + | "VALIDITY" : [ 1, 1, 1, 1 ], + | "DATA" : [ -1234, 0, 1365383415567000, 33057298500000000 ] 
+ | } ] + | } ] + |} + """.stripMargin - val schema = new JSONSchema(Seq(new TimestampType("timestamp"))) - val us_data = data.map(_.getTime * 1000) // convert to microseconds - val columns = Seq( - new PrimitiveColumn("timestamp", data.length, data.map(_ => true), us_data)) - val batch = new JSONRecordBatch(data.length, columns) - val json = new JSONFile(schema, Seq(batch)) + val sdf = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss.SSS z", Locale.US) + val ts1 = DateTimeUtils.toJavaTimestamp(-1234L) + val ts2 = DateTimeUtils.toJavaTimestamp(0L) + val ts3 = new Timestamp(sdf.parse("2013-04-08 01:10:15.567 UTC").getTime) + val ts4 = new Timestamp(sdf.parse("3017-07-18 14:55:00.000 UTC").getTime) + val data = Seq(ts1, ts2, ts3, ts4) val df = data.toDF("timestamp") collectAndValidate(df, json, "timestampData.json") - */ } test("floating-point NaN") { From 78119ca1a4b9b554246d7ced2c669f294c272165 Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Wed, 19 Jul 2017 15:46:19 -0700 Subject: [PATCH 06/34] adding date and timestamp data to python tests, not passing --- python/pyspark/sql/dataframe.py | 2 ++ python/pyspark/sql/tests.py | 11 +++++++---- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 944739bcd207..e6172d77f65c 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -1803,6 +1803,8 @@ def _to_corrected_pandas_type(dt): return np.int32 elif type(dt) == FloatType: return np.float32 + elif type(dt) == DateType or type(dt) == TimestampType: + return 'datetime64[ns]' else: return None diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 29e48a6ccf76..4dfa79c0ab11 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -2882,6 +2882,7 @@ class ArrowTests(ReusedPySparkTestCase): @classmethod def setUpClass(cls): + from datetime import datetime ReusedPySparkTestCase.setUpClass() cls.spark = SparkSession(cls.sc) cls.spark.conf.set("spark.sql.execution.arrow.enable", "true") @@ -2890,10 +2891,12 @@ def setUpClass(cls): StructField("2_int_t", IntegerType(), True), StructField("3_long_t", LongType(), True), StructField("4_float_t", FloatType(), True), - StructField("5_double_t", DoubleType(), True)]) - cls.data = [("a", 1, 10, 0.2, 2.0), - ("b", 2, 20, 0.4, 4.0), - ("c", 3, 30, 0.8, 6.0)] + StructField("5_double_t", DoubleType(), True), + StructField("6_date_t", DateType(), True), + StructField("7_timestamp_t", TimestampType(), True)]) + cls.data = [("a", 1, 10, 0.2, 2.0, datetime(2011, 1, 1), datetime(1970, 1, 1, 0, 0, 0)), + ("b", 2, 20, 0.4, 4.0, datetime(2012, 2, 2), datetime(2012, 2, 2, 2, 2, 2)), + ("c", 3, 30, 0.8, 6.0, datetime(2013, 3, 3), datetime(2013, 3, 3, 3, 3, 3))] def assertFramesEqual(self, df_with_arrow, df_without): msg = ("DataFrame from Arrow is not equal" + From b709d78c03701f92f617651879ee33dada0c4da1 Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Wed, 19 Jul 2017 16:17:57 -0700 Subject: [PATCH 07/34] TimestampType is correctly inferred as datetime64[ns] --- python/pyspark/sql/dataframe.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index e6172d77f65c..bbb66a8a6069 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -1793,6 +1793,7 @@ def _to_corrected_pandas_type(dt): """ When converting Spark SQL records to Pandas DataFrame, the inferred data type may be wrong. 
This method gets the corrected data type for Pandas if that type may be inferred uncorrectly. + NOTE: DateType is inferred incorrectly as 'object', TimestampType is correct with datetime64[ns] """ import numpy as np if type(dt) == ByteType: @@ -1803,7 +1804,7 @@ def _to_corrected_pandas_type(dt): return np.int32 elif type(dt) == FloatType: return np.float32 - elif type(dt) == DateType or type(dt) == TimestampType: + elif type(dt) == DateType: return 'datetime64[ns]' else: return None From e6d8590197c449c17512bf919f64a0426609218a Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Mon, 24 Jul 2017 15:14:35 -0700 Subject: [PATCH 08/34] Adding DateType and TimestampType to ArrowUtils conversions --- .../org/apache/spark/sql/execution/arrow/ArrowUtils.scala | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowUtils.scala index 2caf1ef02909..e71437b42c80 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowUtils.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.arrow import scala.collection.JavaConverters._ import org.apache.arrow.memory.RootAllocator -import org.apache.arrow.vector.types.FloatingPointPrecision +import org.apache.arrow.vector.types.{DateUnit, FloatingPointPrecision, TimeUnit} import org.apache.arrow.vector.types.pojo.{ArrowType, Field, FieldType, Schema} import org.apache.spark.sql.types._ @@ -42,6 +42,8 @@ object ArrowUtils { case StringType => ArrowType.Utf8.INSTANCE case BinaryType => ArrowType.Binary.INSTANCE case DecimalType.Fixed(precision, scale) => new ArrowType.Decimal(precision, scale) + case DateType => new ArrowType.Date(DateUnit.DAY) + case TimestampType => new ArrowType.Timestamp(TimeUnit.MICROSECOND, null) case _ => throw new UnsupportedOperationException(s"Unsupported data type: ${dt.simpleString}") } @@ -58,6 +60,8 @@ object ArrowUtils { case ArrowType.Utf8.INSTANCE => StringType case ArrowType.Binary.INSTANCE => BinaryType case d: ArrowType.Decimal => DecimalType(d.getPrecision, d.getScale) + case date: ArrowType.Date if date.getUnit == DateUnit.DAY => DateType + case ts: ArrowType.Timestamp if ts.getUnit == TimeUnit.MICROSECOND => TimestampType case _ => throw new UnsupportedOperationException(s"Unsupported data type: $dt") } From 719e77ce258c0a462e317f4d6d9de925a4816771 Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Mon, 24 Jul 2017 16:36:07 -0700 Subject: [PATCH 09/34] using default timezone, fixed tests --- python/pyspark/sql/tests.py | 10 ++++++++-- .../apache/spark/sql/execution/arrow/ArrowUtils.scala | 4 +++- 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 67c81568a0a5..2d5c6b89232f 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -3010,9 +3010,9 @@ def setUpClass(cls): StructField("5_double_t", DoubleType(), True), StructField("6_date_t", DateType(), True), StructField("7_timestamp_t", TimestampType(), True)]) - cls.data = [("a", 1, 10, 0.2, 2.0, datetime(2011, 1, 1), datetime(1970, 1, 1, 0, 0, 0)), + cls.data = [("a", 1, 10, 0.2, 2.0, datetime(1969, 1, 1), datetime(1969, 1, 1, 1, 1, 1)), ("b", 2, 20, 0.4, 4.0, datetime(2012, 2, 2), datetime(2012, 2, 2, 2, 2, 2)), - ("c", 3, 30, 0.8, 6.0, datetime(2013, 3, 3), datetime(2013, 3, 3, 3, 3, 3))] + ("c", 3, 30, 0.8, 
6.0, datetime(2100, 3, 3), datetime(2100, 3, 3, 3, 3, 3))] def assertFramesEqual(self, df_with_arrow, df_without): msg = ("DataFrame from Arrow is not equal" + @@ -3039,6 +3039,9 @@ def test_toPandas_arrow_toggle(self): pdf = df.toPandas() self.spark.conf.set("spark.sql.execution.arrow.enable", "true") pdf_arrow = df.toPandas() + # need to remove timezone for comparison + pdf_arrow["7_timestamp_t"] = \ + pdf_arrow["7_timestamp_t"].apply(lambda ts: ts.tz_localize(None)) self.assertFramesEqual(pdf_arrow, pdf) def test_pandas_round_trip(self): @@ -3053,6 +3056,9 @@ def test_pandas_round_trip(self): pdf = pd.DataFrame(data=data_dict) df = self.spark.createDataFrame(self.data, schema=self.schema) pdf_arrow = df.toPandas() + # need to remove timezone for comparison + pdf_arrow["7_timestamp_t"] = \ + pdf_arrow["7_timestamp_t"].apply(lambda ts: ts.tz_localize(None)) self.assertFramesEqual(pdf_arrow, pdf) def test_filtered_frame(self): diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowUtils.scala index e71437b42c80..666ad616faed 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowUtils.scala @@ -23,6 +23,7 @@ import org.apache.arrow.memory.RootAllocator import org.apache.arrow.vector.types.{DateUnit, FloatingPointPrecision, TimeUnit} import org.apache.arrow.vector.types.pojo.{ArrowType, Field, FieldType, Schema} +import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.types._ object ArrowUtils { @@ -43,7 +44,8 @@ object ArrowUtils { case BinaryType => ArrowType.Binary.INSTANCE case DecimalType.Fixed(precision, scale) => new ArrowType.Decimal(precision, scale) case DateType => new ArrowType.Date(DateUnit.DAY) - case TimestampType => new ArrowType.Timestamp(TimeUnit.MICROSECOND, null) + case TimestampType => + new ArrowType.Timestamp(TimeUnit.MICROSECOND, DateTimeUtils.defaultTimeZone().getID) case _ => throw new UnsupportedOperationException(s"Unsupported data type: ${dt.simpleString}") } From 3585520c3524d7a6429a7400f07c74b466c9a229 Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Tue, 25 Jul 2017 10:21:32 -0700 Subject: [PATCH 10/34] fixed scala tests for timestamp --- .../spark/sql/execution/arrow/ArrowConvertersSuite.scala | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala index 96f96304c08b..bed8e56e3209 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala @@ -849,7 +849,8 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { | "name" : "timestamp", | "type" : { | "name" : "timestamp", - | "unit" : "MICROSECOND" + | "unit" : "MICROSECOND", + | "timezone" : "${DateTimeUtils.defaultTimeZone().getID}" | }, | "nullable" : true, | "children" : [ ], @@ -1138,11 +1139,6 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { runUnsupported { arrayData.toDF().toArrowPayload.collect() } runUnsupported { mapData.toDF().toArrowPayload.collect() } runUnsupported { complexData.toArrowPayload.collect() } - - val sdf = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss.SSS z", Locale.US) 
- val ts1 = new Timestamp(sdf.parse("2013-04-08 01:10:15.567 UTC").getTime) - val ts2 = new Timestamp(sdf.parse("2013-04-08 13:10:10.789 UTC").getTime) - runUnsupported { Seq(ts1, ts2).toDF("timestamp").toArrowPayload.collect() } } test("test Arrow Validator") { From f977d0bab1ffaa8d18bacb4127a0e5219df2072a Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Wed, 26 Jul 2017 11:14:56 -0700 Subject: [PATCH 11/34] Adding sync between Python and Java default timezones --- python/pyspark/sql/tests.py | 15 +++ .../sql/catalyst/util/DateTimeTestUtils.scala | 7 ++ .../arrow/ArrowConvertersSuite.scala | 98 ++++++++++--------- 3 files changed, 72 insertions(+), 48 deletions(-) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 2d5c6b89232f..c95acce658ab 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -3000,6 +3000,14 @@ class ArrowTests(ReusedPySparkTestCase): def setUpClass(cls): from datetime import datetime ReusedPySparkTestCase.setUpClass() + + # Synchronize default timezone between Python and Java + tz = "America/Los_Angeles" + os.environ["TZ"] = tz + time.tzset() + cls.old_tz = cls.sc._jvm.org.apache.spark.sql.catalyst.util.DateTimeTestUtils\ + .setDefaultTimeZone(tz) + cls.spark = SparkSession(cls.sc) cls.spark.conf.set("spark.sql.execution.arrow.enable", "true") cls.schema = StructType([ @@ -3014,6 +3022,13 @@ def setUpClass(cls): ("b", 2, 20, 0.4, 4.0, datetime(2012, 2, 2), datetime(2012, 2, 2, 2, 2, 2)), ("c", 3, 30, 0.8, 6.0, datetime(2100, 3, 3), datetime(2100, 3, 3, 3, 3, 3))] + @classmethod + def tearDownClass(cls): + del os.environ["TZ"] + time.tzset() + cls.sc._jvm.org.apache.spark.sql.catalyst.util.DateTimeTestUtils\ + .setDefaultTimeZone(cls.old_tz) + def assertFramesEqual(self, df_with_arrow, df_without): msg = ("DataFrame from Arrow is not equal" + ("\n\nWith Arrow:\n%s\n%s" % (df_with_arrow, df_with_arrow.dtypes)) + diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeTestUtils.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeTestUtils.scala index 0c1feb3aa088..be57c990bf13 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeTestUtils.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeTestUtils.scala @@ -37,4 +37,11 @@ object DateTimeTestUtils { DateTimeUtils.resetThreadLocals() } } + + def setDefaultTimeZone(id: String): String = { + val originalDefaultTimeZone = DateTimeUtils.defaultTimeZone().getID + DateTimeUtils.resetThreadLocals() + TimeZone.setDefault(TimeZone.getTimeZone(id)) + originalDefaultTimeZone + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala index bed8e56e3209..8b0d5c2e0f57 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala @@ -20,7 +20,7 @@ import java.io.File import java.nio.charset.StandardCharsets import java.sql.{Date, Timestamp} import java.text.SimpleDateFormat -import java.util.Locale +import java.util.{Locale, TimeZone} import com.google.common.io.Files import org.apache.arrow.memory.RootAllocator @@ -31,7 +31,7 @@ import org.scalatest.BeforeAndAfterAll import org.apache.spark.SparkException import org.apache.spark.sql.{DataFrame, Row} -import org.apache.spark.sql.catalyst.util.DateTimeUtils 
+import org.apache.spark.sql.catalyst.util.{DateTimeTestUtils, DateTimeUtils} import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types.{BinaryType, StructField, StructType} import org.apache.spark.util.Utils @@ -841,52 +841,54 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { } test("timestamp type conversion") { - val json = - s""" - |{ - | "schema" : { - | "fields" : [ { - | "name" : "timestamp", - | "type" : { - | "name" : "timestamp", - | "unit" : "MICROSECOND", - | "timezone" : "${DateTimeUtils.defaultTimeZone().getID}" - | }, - | "nullable" : true, - | "children" : [ ], - | "typeLayout" : { - | "vectors" : [ { - | "type" : "VALIDITY", - | "typeBitWidth" : 1 - | }, { - | "type" : "DATA", - | "typeBitWidth" : 64 - | } ] - | } - | } ] - | }, - | "batches" : [ { - | "count" : 4, - | "columns" : [ { - | "name" : "timestamp", - | "count" : 4, - | "VALIDITY" : [ 1, 1, 1, 1 ], - | "DATA" : [ -1234, 0, 1365383415567000, 33057298500000000 ] - | } ] - | } ] - |} - """.stripMargin - - val sdf = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss.SSS z", Locale.US) - val ts1 = DateTimeUtils.toJavaTimestamp(-1234L) - val ts2 = DateTimeUtils.toJavaTimestamp(0L) - val ts3 = new Timestamp(sdf.parse("2013-04-08 01:10:15.567 UTC").getTime) - val ts4 = new Timestamp(sdf.parse("3017-07-18 14:55:00.000 UTC").getTime) - val data = Seq(ts1, ts2, ts3, ts4) - - val df = data.toDF("timestamp") - - collectAndValidate(df, json, "timestampData.json") + DateTimeTestUtils.withDefaultTimeZone(TimeZone.getTimeZone("America/Los_Angeles")) { + val json = + s""" + |{ + | "schema" : { + | "fields" : [ { + | "name" : "timestamp", + | "type" : { + | "name" : "timestamp", + | "unit" : "MICROSECOND", + | "timezone" : "${DateTimeUtils.defaultTimeZone().getID}" + | }, + | "nullable" : true, + | "children" : [ ], + | "typeLayout" : { + | "vectors" : [ { + | "type" : "VALIDITY", + | "typeBitWidth" : 1 + | }, { + | "type" : "DATA", + | "typeBitWidth" : 64 + | } ] + | } + | } ] + | }, + | "batches" : [ { + | "count" : 4, + | "columns" : [ { + | "name" : "timestamp", + | "count" : 4, + | "VALIDITY" : [ 1, 1, 1, 1 ], + | "DATA" : [ -1234, 0, 1365383415567000, 33057298500000000 ] + | } ] + | } ] + |} + """.stripMargin + + val sdf = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss.SSS z", Locale.US) + val ts1 = DateTimeUtils.toJavaTimestamp(-1234L) + val ts2 = DateTimeUtils.toJavaTimestamp(0L) + val ts3 = new Timestamp(sdf.parse("2013-04-08 01:10:15.567 UTC").getTime) + val ts4 = new Timestamp(sdf.parse("3017-07-18 14:55:00.000 UTC").getTime) + val data = Seq(ts1, ts2, ts3, ts4) + + val df = data.toDF("timestamp") + + collectAndValidate(df, json, "timestampData.json") + } } test("floating-point NaN") { From 3b83d7acf17433b1f5581f0b8b87c54a91309839 Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Thu, 27 Jul 2017 11:07:26 -0700 Subject: [PATCH 12/34] added date timestamp writers, fixed tests --- python/pyspark/sql/tests.py | 4 +-- .../sql/execution/arrow/ArrowWriter.scala | 34 ++++++++++++++++--- 2 files changed, 32 insertions(+), 6 deletions(-) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 4510eba13de8..9b89265c4fce 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -3036,8 +3036,8 @@ def assertFramesEqual(self, df_with_arrow, df_without): self.assertTrue(df_without.equals(df_with_arrow), msg=msg) def test_unsupported_datatype(self): - schema = StructType([StructField("dt", DateType(), True)]) - df = 
self.spark.createDataFrame([(datetime.date(1970, 1, 1),)], schema=schema) + schema = StructType([StructField("decimal", DecimalType(), True)]) + df = self.spark.createDataFrame([(None,)], schema=schema) with QuietTest(self.sc): self.assertRaises(Exception, lambda: df.toPandas()) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala index 11ba04d2ce9a..95d4f71180c0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala @@ -21,7 +21,6 @@ import scala.collection.JavaConverters._ import org.apache.arrow.vector._ import org.apache.arrow.vector.complex._ -import org.apache.arrow.vector.util.DecimalUtility import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.SpecializedGetters @@ -55,6 +54,8 @@ object ArrowWriter { case (DoubleType, vector: NullableFloat8Vector) => new DoubleWriter(vector) case (StringType, vector: NullableVarCharVector) => new StringWriter(vector) case (BinaryType, vector: NullableVarBinaryVector) => new BinaryWriter(vector) + case (DateType, vector: NullableDateDayVector) => new DateWriter(vector) + case (TimestampType, vector: NullableTimeStampMicroTZVector) => new TimestampWriter(vector) case (ArrayType(_, _), vector: ListVector) => val elementVector = createFieldWriter(vector.getDataVector()) new ArrayWriter(vector, elementVector) @@ -69,9 +70,7 @@ object ArrowWriter { } } -class ArrowWriter( - val root: VectorSchemaRoot, - fields: Array[ArrowFieldWriter]) { +class ArrowWriter(val root: VectorSchemaRoot, fields: Array[ArrowFieldWriter]) { def schema: StructType = StructType(fields.map { f => StructField(f.name, f.dataType, f.nullable) @@ -254,6 +253,33 @@ private[arrow] class BinaryWriter( } } +private[arrow] class DateWriter(val valueVector: NullableDateDayVector) extends ArrowFieldWriter { + + override def valueMutator: NullableDateDayVector#Mutator = valueVector.getMutator() + + override def setNull(): Unit = { + valueMutator.setNull(count) + } + + override def setValue(input: SpecializedGetters, ordinal: Int): Unit = { + valueMutator.setSafe(count, input.getInt(ordinal)) + } +} + +private[arrow] class TimestampWriter( + val valueVector: NullableTimeStampMicroTZVector) extends ArrowFieldWriter { + + override def valueMutator: NullableTimeStampMicroTZVector#Mutator = valueVector.getMutator() + + override def setNull(): Unit = { + valueMutator.setNull(count) + } + + override def setValue(input: SpecializedGetters, ordinal: Int): Unit = { + valueMutator.setSafe(count, input.getLong(ordinal)) + } +} + private[arrow] class ArrayWriter( val valueVector: ListVector, val elementWriter: ArrowFieldWriter) extends ArrowFieldWriter { From a6009a5f06011718b02d491edf8d88d4248c8f0b Mon Sep 17 00:00:00 2001 From: Takuya UESHIN Date: Fri, 28 Jul 2017 13:31:51 +0900 Subject: [PATCH 13/34] Modify ArrowUtils to have timeZoneId when convert schema to Arrow schema. Modify ArrowConverters to use session local timezone. Fix tests. 
closes #25 --- python/pyspark/sql/tests.py | 7 +++--- .../sql/catalyst/util/DateTimeTestUtils.scala | 7 ------ .../scala/org/apache/spark/sql/Dataset.scala | 4 +++- .../sql/execution/arrow/ArrowConverters.scala | 3 ++- .../sql/execution/arrow/ArrowUtils.scala | 20 +++++++++-------- .../arrow/ArrowConvertersSuite.scala | 21 ++++++++++-------- .../sql/execution/arrow/ArrowUtilsSuite.scala | 22 +++++++++++++++++++ 7 files changed, 53 insertions(+), 31 deletions(-) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 9b89265c4fce..d4a439d9eb75 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -3005,10 +3005,9 @@ def setUpClass(cls): tz = "America/Los_Angeles" os.environ["TZ"] = tz time.tzset() - cls.old_tz = cls.sc._jvm.org.apache.spark.sql.catalyst.util.DateTimeTestUtils\ - .setDefaultTimeZone(tz) cls.spark = SparkSession(cls.sc) + cls.spark.conf.set("spark.sql.session.timeZone", tz) cls.spark.conf.set("spark.sql.execution.arrow.enable", "true") cls.schema = StructType([ StructField("1_str_t", StringType(), True), @@ -3026,8 +3025,8 @@ def setUpClass(cls): def tearDownClass(cls): del os.environ["TZ"] time.tzset() - cls.sc._jvm.org.apache.spark.sql.catalyst.util.DateTimeTestUtils\ - .setDefaultTimeZone(cls.old_tz) + + cls.spark.stop() def assertFramesEqual(self, df_with_arrow, df_without): msg = ("DataFrame from Arrow is not equal" + diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeTestUtils.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeTestUtils.scala index be57c990bf13..0c1feb3aa088 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeTestUtils.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeTestUtils.scala @@ -37,11 +37,4 @@ object DateTimeTestUtils { DateTimeUtils.resetThreadLocals() } } - - def setDefaultTimeZone(id: String): String = { - val originalDefaultTimeZone = DateTimeUtils.defaultTimeZone().getID - DateTimeUtils.resetThreadLocals() - TimeZone.setDefault(TimeZone.getTimeZone(id)) - originalDefaultTimeZone - } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 9007367f5aa8..b22454385fb0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -3090,9 +3090,11 @@ class Dataset[T] private[sql]( private[sql] def toArrowPayload: RDD[ArrowPayload] = { val schemaCaptured = this.schema val maxRecordsPerBatch = sparkSession.sessionState.conf.arrowMaxRecordsPerBatch + val timeZoneId = Option(sparkSession.sessionState.conf.sessionLocalTimeZone) queryExecution.toRdd.mapPartitionsInternal { iter => val context = TaskContext.get() - ArrowConverters.toPayloadIterator(iter, schemaCaptured, maxRecordsPerBatch, context) + ArrowConverters.toPayloadIterator( + iter, schemaCaptured, maxRecordsPerBatch, timeZoneId, context) } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala index 240f38f5bfeb..cd565909bbcb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala @@ -60,9 +60,10 @@ private[sql] object ArrowConverters { rowIter: Iterator[InternalRow], schema: StructType, maxRecordsPerBatch: 
Int, + timeZoneId: Option[String], context: TaskContext): Iterator[ArrowPayload] = { - val arrowSchema = ArrowUtils.toArrowSchema(schema) + val arrowSchema = ArrowUtils.toArrowSchema(schema, timeZoneId) val allocator = ArrowUtils.rootAllocator.newChildAllocator("toPayloadIterator", 0, Long.MaxValue) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowUtils.scala index 666ad616faed..e44d960210a6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowUtils.scala @@ -32,7 +32,7 @@ object ArrowUtils { // todo: support more types. - def toArrowType(dt: DataType): ArrowType = dt match { + def toArrowType(dt: DataType, timeZoneId: Option[String] = None): ArrowType = dt match { case BooleanType => ArrowType.Bool.INSTANCE case ByteType => new ArrowType.Int(8, true) case ShortType => new ArrowType.Int(8 * 2, true) @@ -44,8 +44,8 @@ object ArrowUtils { case BinaryType => ArrowType.Binary.INSTANCE case DecimalType.Fixed(precision, scale) => new ArrowType.Decimal(precision, scale) case DateType => new ArrowType.Date(DateUnit.DAY) - case TimestampType => - new ArrowType.Timestamp(TimeUnit.MICROSECOND, DateTimeUtils.defaultTimeZone().getID) + case TimestampType => new ArrowType.Timestamp(TimeUnit.MICROSECOND, + timeZoneId.getOrElse(DateTimeUtils.defaultTimeZone().getID)) case _ => throw new UnsupportedOperationException(s"Unsupported data type: ${dt.simpleString}") } @@ -67,19 +67,21 @@ object ArrowUtils { case _ => throw new UnsupportedOperationException(s"Unsupported data type: $dt") } - def toArrowField(name: String, dt: DataType, nullable: Boolean): Field = { + def toArrowField( + name: String, dt: DataType, nullable: Boolean, timeZoneId: Option[String] = None): Field = { dt match { case ArrayType(elementType, containsNull) => val fieldType = new FieldType(nullable, ArrowType.List.INSTANCE, null) - new Field(name, fieldType, Seq(toArrowField("element", elementType, containsNull)).asJava) + new Field(name, fieldType, + Seq(toArrowField("element", elementType, containsNull, timeZoneId)).asJava) case StructType(fields) => val fieldType = new FieldType(nullable, ArrowType.Struct.INSTANCE, null) new Field(name, fieldType, fields.map { field => - toArrowField(field.name, field.dataType, field.nullable) + toArrowField(field.name, field.dataType, field.nullable, timeZoneId) }.toSeq.asJava) case dataType => - val fieldType = new FieldType(nullable, toArrowType(dataType), null) + val fieldType = new FieldType(nullable, toArrowType(dataType, timeZoneId), null) new Field(name, fieldType, Seq.empty[Field].asJava) } } @@ -100,9 +102,9 @@ object ArrowUtils { } } - def toArrowSchema(schema: StructType): Schema = { + def toArrowSchema(schema: StructType, timeZoneId: Option[String] = None): Schema = { new Schema(schema.map { field => - toArrowField(field.name, field.dataType, field.nullable) + toArrowField(field.name, field.dataType, field.nullable, timeZoneId) }.asJava) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala index 6cdffc7b4c01..a2ae55c8a4f1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala @@ -20,7 +20,7 @@ 
import java.io.File import java.nio.charset.StandardCharsets import java.sql.{Date, Timestamp} import java.text.SimpleDateFormat -import java.util.{Locale, TimeZone} +import java.util.Locale import com.google.common.io.Files import org.apache.arrow.memory.RootAllocator @@ -31,7 +31,8 @@ import org.scalatest.BeforeAndAfterAll import org.apache.spark.SparkException import org.apache.spark.sql.{DataFrame, Row} -import org.apache.spark.sql.catalyst.util.{DateTimeTestUtils, DateTimeUtils} +import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types.{BinaryType, IntegerType, StructField, StructType} import org.apache.spark.util.Utils @@ -841,7 +842,7 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { } test("timestamp type conversion") { - DateTimeTestUtils.withDefaultTimeZone(TimeZone.getTimeZone("America/Los_Angeles")) { + withSQLConf(SQLConf.SESSION_LOCAL_TIMEZONE.key -> "America/Los_Angeles") { val json = s""" |{ @@ -851,7 +852,7 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { | "type" : { | "name" : "timestamp", | "unit" : "MICROSECOND", - | "timezone" : "${DateTimeUtils.defaultTimeZone().getID}" + | "timezone" : "America/Los_Angeles" | }, | "nullable" : true, | "children" : [ ], @@ -887,7 +888,7 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { val df = data.toDF("timestamp") - collectAndValidate(df, json, "timestampData.json") + collectAndValidate(df, json, "timestampData.json", Option("America/Los_Angeles")) } } @@ -1720,22 +1721,24 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { } /** Test that a converted DataFrame to Arrow record batch equals batch read from JSON file */ - private def collectAndValidate(df: DataFrame, json: String, file: String): Unit = { + private def collectAndValidate( + df: DataFrame, json: String, file: String, timeZoneId: Option[String] = None): Unit = { // NOTE: coalesce to single partition because can only load 1 batch in validator val arrowPayload = df.coalesce(1).toArrowPayload.collect().head val tempFile = new File(tempDataPath, file) Files.write(json, tempFile, StandardCharsets.UTF_8) - validateConversion(df.schema, arrowPayload, tempFile) + validateConversion(df.schema, arrowPayload, tempFile, timeZoneId) } private def validateConversion( sparkSchema: StructType, arrowPayload: ArrowPayload, - jsonFile: File): Unit = { + jsonFile: File, + timeZoneId: Option[String] = None): Unit = { val allocator = new RootAllocator(Long.MaxValue) val jsonReader = new JsonFileReader(jsonFile, allocator) - val arrowSchema = ArrowUtils.toArrowSchema(sparkSchema) + val arrowSchema = ArrowUtils.toArrowSchema(sparkSchema, timeZoneId) val jsonSchema = jsonReader.start() Validator.compareSchemas(arrowSchema, jsonSchema) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowUtilsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowUtilsSuite.scala index 638619fd39d0..d67e3b69a775 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowUtilsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowUtilsSuite.scala @@ -17,7 +17,10 @@ package org.apache.spark.sql.execution.arrow +import org.apache.arrow.vector.types.pojo.{ArrowType, Field, FieldType, Schema} + import org.apache.spark.SparkFunSuite +import 
org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.types._ class ArrowUtilsSuite extends SparkFunSuite { @@ -42,6 +45,25 @@ class ArrowUtilsSuite extends SparkFunSuite { roundtrip(StringType) roundtrip(BinaryType) roundtrip(DecimalType.SYSTEM_DEFAULT) + roundtrip(DateType) + } + + test("timestamp") { + val schema = new StructType().add("value", TimestampType) + val arrowSchema = ArrowUtils.toArrowSchema(schema) + val fieldType = arrowSchema.findField("value").getType.asInstanceOf[ArrowType.Timestamp] + assert(fieldType.getTimezone() === DateTimeUtils.defaultTimeZone().getID) + assert(ArrowUtils.fromArrowSchema(arrowSchema) === schema) + + def roundtripWithTz(timeZoneId: String): Unit = { + val arrowSchema = ArrowUtils.toArrowSchema(schema, Option(timeZoneId)) + val fieldType = arrowSchema.findField("value").getType.asInstanceOf[ArrowType.Timestamp] + assert(fieldType.getTimezone() === timeZoneId) + assert(ArrowUtils.fromArrowSchema(arrowSchema) === schema) + } + roundtripWithTz("Asia/Tokyo") + roundtripWithTz("UTC") + roundtripWithTz("America/Los_Angeles") } test("array") { From 2ec98cc32ae1587a31bf9e5a3b80f2e078ce10e5 Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Tue, 1 Aug 2017 15:29:35 -0700 Subject: [PATCH 14/34] fixed python test tearDownClass --- python/pyspark/sql/tests.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index d4a439d9eb75..f284878cd0ad 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -3025,7 +3025,7 @@ def setUpClass(cls): def tearDownClass(cls): del os.environ["TZ"] time.tzset() - + ReusedPySparkTestCase.tearDownClass() cls.spark.stop() def assertFramesEqual(self, df_with_arrow, df_without): From c29018c529aefa6842a40c44de6d7e452e3d30e8 Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Wed, 2 Aug 2017 10:42:25 -0700 Subject: [PATCH 15/34] using Date.valueOf for tests instead --- .../spark/sql/execution/arrow/ArrowConvertersSuite.scala | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala index a2ae55c8a4f1..7d61426a33b3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala @@ -830,11 +830,10 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { |} """.stripMargin - val sdf = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss.SSS z", Locale.US) val d1 = DateTimeUtils.toJavaDate(-1) // "1969-12-31" val d2 = DateTimeUtils.toJavaDate(0) // "1970-01-01" - val d3 = new Date(sdf.parse("2015-04-08 13:10:15.000 UTC").getTime) - val d4 = new Date(sdf.parse("3017-07-18 14:55:00.000 UTC").getTime) + val d3 = Date.valueOf("2015-04-08") + val d4 = Date.valueOf("3017-07-18") val df = Seq(d1, d2, d3, d4).toDF("date") From 7dbdb1fb309ac1cb9d3085cfe9c78c647d4ace73 Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Mon, 14 Aug 2017 14:44:28 -0700 Subject: [PATCH 16/34] Made timezone id required for TimestampType --- .../sql/execution/arrow/ArrowUtils.scala | 17 +++++++++----- .../sql/execution/arrow/ArrowWriter.scala | 4 ++-- .../sql/execution/arrow/ArrowUtilsSuite.scala | 10 ++++----- .../execution/arrow/ArrowWriterSuite.scala | 16 +++++++------- .../vectorized/ArrowColumnVectorSuite.scala | 22 
+++++++++---------- 5 files changed, 37 insertions(+), 32 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowUtils.scala index e44d960210a6..bcd69ce1c912 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowUtils.scala @@ -32,7 +32,8 @@ object ArrowUtils { // todo: support more types. - def toArrowType(dt: DataType, timeZoneId: Option[String] = None): ArrowType = dt match { + /** Maps data type from Spark to Arrow. NOTE: timeZoneId required for TimestampTypes */ + def toArrowType(dt: DataType, timeZoneId: Option[String]): ArrowType = dt match { case BooleanType => ArrowType.Bool.INSTANCE case ByteType => new ArrowType.Int(8, true) case ShortType => new ArrowType.Int(8 * 2, true) @@ -44,8 +45,12 @@ object ArrowUtils { case BinaryType => ArrowType.Binary.INSTANCE case DecimalType.Fixed(precision, scale) => new ArrowType.Decimal(precision, scale) case DateType => new ArrowType.Date(DateUnit.DAY) - case TimestampType => new ArrowType.Timestamp(TimeUnit.MICROSECOND, - timeZoneId.getOrElse(DateTimeUtils.defaultTimeZone().getID)) + case TimestampType => + timeZoneId match { + case Some(id) => new ArrowType.Timestamp(TimeUnit.MICROSECOND, id) + case None => + throw new UnsupportedOperationException("TimestampType must supply timeZoneId parameter") + } case _ => throw new UnsupportedOperationException(s"Unsupported data type: ${dt.simpleString}") } @@ -67,8 +72,9 @@ object ArrowUtils { case _ => throw new UnsupportedOperationException(s"Unsupported data type: $dt") } + /** Maps field from Spark to Arrow. NOTE: timeZoneId required for TimestampType */ def toArrowField( - name: String, dt: DataType, nullable: Boolean, timeZoneId: Option[String] = None): Field = { + name: String, dt: DataType, nullable: Boolean, timeZoneId: Option[String]): Field = { dt match { case ArrayType(elementType, containsNull) => val fieldType = new FieldType(nullable, ArrowType.List.INSTANCE, null) @@ -102,7 +108,8 @@ object ArrowUtils { } } - def toArrowSchema(schema: StructType, timeZoneId: Option[String] = None): Schema = { + /** Maps schema from Spark to Arrow. 
NOTE: timeZoneId required for TimestampType in StructType */ + def toArrowSchema(schema: StructType, timeZoneId: Option[String]): Schema = { new Schema(schema.map { field => toArrowField(field.name, field.dataType, field.nullable, timeZoneId) }.asJava) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala index 95d4f71180c0..331a26071cb5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala @@ -28,8 +28,8 @@ import org.apache.spark.sql.types._ object ArrowWriter { - def create(schema: StructType): ArrowWriter = { - val arrowSchema = ArrowUtils.toArrowSchema(schema) + def create(schema: StructType, timeZoneId: Option[String]): ArrowWriter = { + val arrowSchema = ArrowUtils.toArrowSchema(schema, timeZoneId) val root = VectorSchemaRoot.create(arrowSchema, ArrowUtils.rootAllocator) create(root) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowUtilsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowUtilsSuite.scala index d67e3b69a775..89285d83d031 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowUtilsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowUtilsSuite.scala @@ -28,7 +28,7 @@ class ArrowUtilsSuite extends SparkFunSuite { def roundtrip(dt: DataType): Unit = { dt match { case schema: StructType => - assert(ArrowUtils.fromArrowSchema(ArrowUtils.toArrowSchema(schema)) === schema) + assert(ArrowUtils.fromArrowSchema(ArrowUtils.toArrowSchema(schema, None)) === schema) case _ => roundtrip(new StructType().add("value", dt)) } @@ -49,18 +49,16 @@ class ArrowUtilsSuite extends SparkFunSuite { } test("timestamp") { - val schema = new StructType().add("value", TimestampType) - val arrowSchema = ArrowUtils.toArrowSchema(schema) - val fieldType = arrowSchema.findField("value").getType.asInstanceOf[ArrowType.Timestamp] - assert(fieldType.getTimezone() === DateTimeUtils.defaultTimeZone().getID) - assert(ArrowUtils.fromArrowSchema(arrowSchema) === schema) def roundtripWithTz(timeZoneId: String): Unit = { + val schema = new StructType().add("value", TimestampType) val arrowSchema = ArrowUtils.toArrowSchema(schema, Option(timeZoneId)) val fieldType = arrowSchema.findField("value").getType.asInstanceOf[ArrowType.Timestamp] assert(fieldType.getTimezone() === timeZoneId) assert(ArrowUtils.fromArrowSchema(arrowSchema) === schema) } + + roundtripWithTz(DateTimeUtils.defaultTimeZone().getID) roundtripWithTz("Asia/Tokyo") roundtripWithTz("UTC") roundtripWithTz("America/Los_Angeles") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowWriterSuite.scala index e9a629315f5f..006bf6b43061 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowWriterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowWriterSuite.scala @@ -27,9 +27,9 @@ import org.apache.spark.unsafe.types.UTF8String class ArrowWriterSuite extends SparkFunSuite { test("simple") { - def check(dt: DataType, data: Seq[Any]): Unit = { + def check(dt: DataType, data: Seq[Any], timeZoneId: Option[String] = None): Unit = { val schema = new StructType().add("value", dt, nullable = true) - val writer = ArrowWriter.create(schema) + val 
writer = ArrowWriter.create(schema, timeZoneId) assert(writer.schema === schema) data.foreach { datum => @@ -69,9 +69,9 @@ class ArrowWriterSuite extends SparkFunSuite { } test("get multiple") { - def check(dt: DataType, data: Seq[Any]): Unit = { + def check(dt: DataType, data: Seq[Any], timeZoneId: Option[String] = None): Unit = { val schema = new StructType().add("value", dt, nullable = false) - val writer = ArrowWriter.create(schema) + val writer = ArrowWriter.create(schema, timeZoneId) assert(writer.schema === schema) data.foreach { datum => @@ -105,7 +105,7 @@ class ArrowWriterSuite extends SparkFunSuite { test("array") { val schema = new StructType() .add("arr", ArrayType(IntegerType, containsNull = true), nullable = true) - val writer = ArrowWriter.create(schema) + val writer = ArrowWriter.create(schema, None) assert(writer.schema === schema) writer.write(InternalRow(ArrayData.toArrayData(Array(1, 2, 3)))) @@ -144,7 +144,7 @@ class ArrowWriterSuite extends SparkFunSuite { test("nested array") { val schema = new StructType().add("nested", ArrayType(ArrayType(IntegerType))) - val writer = ArrowWriter.create(schema) + val writer = ArrowWriter.create(schema, None) assert(writer.schema === schema) writer.write(InternalRow(ArrayData.toArrayData(Array( @@ -195,7 +195,7 @@ class ArrowWriterSuite extends SparkFunSuite { test("struct") { val schema = new StructType() .add("struct", new StructType().add("i", IntegerType).add("str", StringType)) - val writer = ArrowWriter.create(schema) + val writer = ArrowWriter.create(schema, None) assert(writer.schema === schema) writer.write(InternalRow(InternalRow(1, UTF8String.fromString("str1")))) @@ -231,7 +231,7 @@ class ArrowWriterSuite extends SparkFunSuite { test("nested struct") { val schema = new StructType().add("struct", new StructType().add("nested", new StructType().add("i", IntegerType).add("str", StringType))) - val writer = ArrowWriter.create(schema) + val writer = ArrowWriter.create(schema, None) assert(writer.schema === schema) writer.write(InternalRow(InternalRow(InternalRow(1, UTF8String.fromString("str1"))))) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ArrowColumnVectorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ArrowColumnVectorSuite.scala index d24a9e1f4bd1..33953dc11852 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ArrowColumnVectorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ArrowColumnVectorSuite.scala @@ -29,7 +29,7 @@ class ArrowColumnVectorSuite extends SparkFunSuite { test("boolean") { val allocator = ArrowUtils.rootAllocator.newChildAllocator("boolean", 0, Long.MaxValue) - val vector = ArrowUtils.toArrowField("boolean", BooleanType, nullable = true) + val vector = ArrowUtils.toArrowField("boolean", BooleanType, nullable = true, None) .createVector(allocator).asInstanceOf[NullableBitVector] vector.allocateNew() val mutator = vector.getMutator() @@ -58,7 +58,7 @@ class ArrowColumnVectorSuite extends SparkFunSuite { test("byte") { val allocator = ArrowUtils.rootAllocator.newChildAllocator("byte", 0, Long.MaxValue) - val vector = ArrowUtils.toArrowField("byte", ByteType, nullable = true) + val vector = ArrowUtils.toArrowField("byte", ByteType, nullable = true, None) .createVector(allocator).asInstanceOf[NullableTinyIntVector] vector.allocateNew() val mutator = vector.getMutator() @@ -87,7 +87,7 @@ class ArrowColumnVectorSuite extends SparkFunSuite { test("short") { val allocator = 
ArrowUtils.rootAllocator.newChildAllocator("short", 0, Long.MaxValue) - val vector = ArrowUtils.toArrowField("short", ShortType, nullable = true) + val vector = ArrowUtils.toArrowField("short", ShortType, nullable = true, None) .createVector(allocator).asInstanceOf[NullableSmallIntVector] vector.allocateNew() val mutator = vector.getMutator() @@ -116,7 +116,7 @@ class ArrowColumnVectorSuite extends SparkFunSuite { test("int") { val allocator = ArrowUtils.rootAllocator.newChildAllocator("int", 0, Long.MaxValue) - val vector = ArrowUtils.toArrowField("int", IntegerType, nullable = true) + val vector = ArrowUtils.toArrowField("int", IntegerType, nullable = true, None) .createVector(allocator).asInstanceOf[NullableIntVector] vector.allocateNew() val mutator = vector.getMutator() @@ -145,7 +145,7 @@ class ArrowColumnVectorSuite extends SparkFunSuite { test("long") { val allocator = ArrowUtils.rootAllocator.newChildAllocator("long", 0, Long.MaxValue) - val vector = ArrowUtils.toArrowField("long", LongType, nullable = true) + val vector = ArrowUtils.toArrowField("long", LongType, nullable = true, None) .createVector(allocator).asInstanceOf[NullableBigIntVector] vector.allocateNew() val mutator = vector.getMutator() @@ -174,7 +174,7 @@ class ArrowColumnVectorSuite extends SparkFunSuite { test("float") { val allocator = ArrowUtils.rootAllocator.newChildAllocator("float", 0, Long.MaxValue) - val vector = ArrowUtils.toArrowField("float", FloatType, nullable = true) + val vector = ArrowUtils.toArrowField("float", FloatType, nullable = true, None) .createVector(allocator).asInstanceOf[NullableFloat4Vector] vector.allocateNew() val mutator = vector.getMutator() @@ -203,7 +203,7 @@ class ArrowColumnVectorSuite extends SparkFunSuite { test("double") { val allocator = ArrowUtils.rootAllocator.newChildAllocator("double", 0, Long.MaxValue) - val vector = ArrowUtils.toArrowField("double", DoubleType, nullable = true) + val vector = ArrowUtils.toArrowField("double", DoubleType, nullable = true, None) .createVector(allocator).asInstanceOf[NullableFloat8Vector] vector.allocateNew() val mutator = vector.getMutator() @@ -232,7 +232,7 @@ class ArrowColumnVectorSuite extends SparkFunSuite { test("string") { val allocator = ArrowUtils.rootAllocator.newChildAllocator("string", 0, Long.MaxValue) - val vector = ArrowUtils.toArrowField("string", StringType, nullable = true) + val vector = ArrowUtils.toArrowField("string", StringType, nullable = true, None) .createVector(allocator).asInstanceOf[NullableVarCharVector] vector.allocateNew() val mutator = vector.getMutator() @@ -260,7 +260,7 @@ class ArrowColumnVectorSuite extends SparkFunSuite { test("binary") { val allocator = ArrowUtils.rootAllocator.newChildAllocator("binary", 0, Long.MaxValue) - val vector = ArrowUtils.toArrowField("binary", BinaryType, nullable = true) + val vector = ArrowUtils.toArrowField("binary", BinaryType, nullable = true, None) .createVector(allocator).asInstanceOf[NullableVarBinaryVector] vector.allocateNew() val mutator = vector.getMutator() @@ -288,7 +288,7 @@ class ArrowColumnVectorSuite extends SparkFunSuite { test("array") { val allocator = ArrowUtils.rootAllocator.newChildAllocator("array", 0, Long.MaxValue) - val vector = ArrowUtils.toArrowField("array", ArrayType(IntegerType), nullable = true) + val vector = ArrowUtils.toArrowField("array", ArrayType(IntegerType), nullable = true, None) .createVector(allocator).asInstanceOf[ListVector] vector.allocateNew() val mutator = vector.getMutator() @@ -345,7 +345,7 @@ class ArrowColumnVectorSuite 
extends SparkFunSuite { test("struct") { val allocator = ArrowUtils.rootAllocator.newChildAllocator("struct", 0, Long.MaxValue) val schema = new StructType().add("int", IntegerType).add("long", LongType) - val vector = ArrowUtils.toArrowField("struct", schema, nullable = true) + val vector = ArrowUtils.toArrowField("struct", schema, nullable = true, None) .createVector(allocator).asInstanceOf[NullableMapVector] vector.allocateNew() val mutator = vector.getMutator() From c3f4e4d6ea471c5765014582982990c903fb4c83 Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Mon, 14 Aug 2017 15:05:14 -0700 Subject: [PATCH 17/34] added test for TimestampType without specifying timezone id --- .../apache/spark/sql/execution/arrow/ArrowUtilsSuite.scala | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowUtilsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowUtilsSuite.scala index 89285d83d031..21e69d6a3359 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowUtilsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowUtilsSuite.scala @@ -46,6 +46,10 @@ class ArrowUtilsSuite extends SparkFunSuite { roundtrip(BinaryType) roundtrip(DecimalType.SYSTEM_DEFAULT) roundtrip(DateType) + val tsExMsg = intercept[UnsupportedOperationException] { + roundtrip(TimestampType) + } + assert(tsExMsg.getMessage.contains("timeZoneId")) } test("timestamp") { From ddbea248a8e65393d7f880de86585fe02add1a7b Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Mon, 14 Aug 2017 17:22:23 -0700 Subject: [PATCH 18/34] added date and timestamp to ArrowWriter and tests --- .../vectorized/ArrowColumnVector.java | 34 +++++++++++++++++++ .../sql/execution/arrow/ArrowWriter.scala | 7 +++- .../execution/arrow/ArrowWriterSuite.scala | 8 +++++ 3 files changed, 48 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ArrowColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ArrowColumnVector.java index 31dea6ad31b1..d02405df485d 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ArrowColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ArrowColumnVector.java @@ -298,6 +298,10 @@ public ArrowColumnVector(ValueVector vector) { accessor = new StringAccessor((NullableVarCharVector) vector); } else if (vector instanceof NullableVarBinaryVector) { accessor = new BinaryAccessor((NullableVarBinaryVector) vector); + } else if (vector instanceof NullableDateDayVector) { + accessor = new DateAccessor((NullableDateDayVector) vector); + } else if (vector instanceof NullableTimeStampMicroTZVector) { + accessor = new TimestampAccessor((NullableTimeStampMicroTZVector) vector); } else if (vector instanceof ListVector) { ListVector listVector = (ListVector) vector; accessor = new ArrayAccessor(listVector); @@ -561,6 +565,36 @@ final byte[] getBinary(int rowId) { } } + private static class DateAccessor extends ArrowVectorAccessor { + + private final NullableDateDayVector.Accessor accessor; + + DateAccessor(NullableDateDayVector vector) { + super(vector); + this.accessor = vector.getAccessor(); + } + + @Override + final int getInt(int rowId) { + return accessor.get(rowId); + } + } + + private static class TimestampAccessor extends ArrowVectorAccessor { + + private final NullableTimeStampMicroTZVector.Accessor accessor; + + TimestampAccessor(NullableTimeStampMicroTZVector 
vector) { + super(vector); + this.accessor = vector.getAccessor(); + } + + @Override + final long getLong(int rowId) { + return accessor.get(rowId); + } + } + private static class ArrayAccessor extends ArrowVectorAccessor { private final UInt4Vector.Accessor accessor; diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala index 331a26071cb5..5fc546e46d13 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala @@ -21,6 +21,7 @@ import scala.collection.JavaConverters._ import org.apache.arrow.vector._ import org.apache.arrow.vector.complex._ +import org.apache.arrow.vector.types.pojo.ArrowType import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.SpecializedGetters @@ -55,7 +56,11 @@ object ArrowWriter { case (StringType, vector: NullableVarCharVector) => new StringWriter(vector) case (BinaryType, vector: NullableVarBinaryVector) => new BinaryWriter(vector) case (DateType, vector: NullableDateDayVector) => new DateWriter(vector) - case (TimestampType, vector: NullableTimeStampMicroTZVector) => new TimestampWriter(vector) + case (TimestampType, vector: NullableTimeStampMicroTZVector) + // TODO: Should be able to access timezone from vector + if field.getType.isInstanceOf[ArrowType.Timestamp] && + field.getType.asInstanceOf[ArrowType.Timestamp].getTimezone != null => + new TimestampWriter(vector) case (ArrayType(_, _), vector: ListVector) => val elementVector = createFieldWriter(vector.getDataVector()) new ArrayWriter(vector, elementVector) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowWriterSuite.scala index 006bf6b43061..8abecd4a36e3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowWriterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowWriterSuite.scala @@ -51,6 +51,8 @@ class ArrowWriterSuite extends SparkFunSuite { case DoubleType => reader.getDouble(rowId) case StringType => reader.getUTF8String(rowId) case BinaryType => reader.getBinary(rowId) + case DateType => reader.getInt(rowId) + case TimestampType => reader.getLong(rowId) } assert(value === datum) } @@ -66,6 +68,8 @@ class ArrowWriterSuite extends SparkFunSuite { check(DoubleType, Seq(1.0d, 2.0d, null, 4.0d)) check(StringType, Seq("a", "b", null, "d").map(UTF8String.fromString)) check(BinaryType, Seq("a".getBytes(), "b".getBytes(), null, "d".getBytes())) + check(DateType, Seq(0, 1, 2, null, 4)) + check(TimestampType, Seq(0L, 3.6e9.toLong, null, 8.64e10.toLong), Some("America/Los_Angeles")) } test("get multiple") { @@ -88,6 +92,8 @@ class ArrowWriterSuite extends SparkFunSuite { case LongType => reader.getLongs(0, data.size) case FloatType => reader.getFloats(0, data.size) case DoubleType => reader.getDoubles(0, data.size) + case DateType => reader.getInts(0, data.size) + case TimestampType => reader.getLongs(0, data.size) } assert(values === data) @@ -100,6 +106,8 @@ class ArrowWriterSuite extends SparkFunSuite { check(LongType, (0 until 10).map(_.toLong)) check(FloatType, (0 until 10).map(_.toFloat)) check(DoubleType, (0 until 10).map(_.toDouble)) + check(DateType, (0 until 10)) + check(TimestampType, (0 until 10).map(_ * 4.32e10.toLong), Some("America/Los_Angeles")) } 
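An illustrative sketch, not part of the patch series: the snippet below shows how the pieces added in this commit are intended to fit together — ArrowWriter.create taking a time zone ID and writing a TimestampType value, and the new TimestampAccessor reading the raw microseconds back through ArrowColumnVector. It is written against the Option[String] signature used at this point in the series (later changed to a plain String); the column name and values are invented for the example.

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.execution.arrow.ArrowWriter
import org.apache.spark.sql.execution.vectorized.ArrowColumnVector
import org.apache.spark.sql.types._

val schema = new StructType().add("ts", TimestampType, nullable = true)
// Time zone is still passed as an Option[String] at this point in the series.
val writer = ArrowWriter.create(schema, Some("America/Los_Angeles"))
writer.write(InternalRow(3600L * 1000 * 1000))  // 1970-01-01 01:00:00 UTC, in microseconds
writer.finish()

// The new TimestampAccessor exposes the stored value as a long of microseconds.
val column = new ArrowColumnVector(writer.root.getFieldVectors.get(0))
assert(column.getLong(0) == 3600L * 1000 * 1000)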
test("array") { From c6b597d3d0c740e68516df7a2a02984cc614bef9 Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Wed, 16 Aug 2017 15:36:46 -0700 Subject: [PATCH 19/34] removed unused import --- .../scala/org/apache/spark/sql/execution/arrow/ArrowUtils.scala | 1 - 1 file changed, 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowUtils.scala index bcd69ce1c912..f7e6182a05d8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowUtils.scala @@ -23,7 +23,6 @@ import org.apache.arrow.memory.RootAllocator import org.apache.arrow.vector.types.{DateUnit, FloatingPointPrecision, TimeUnit} import org.apache.arrow.vector.types.pojo.{ArrowType, Field, FieldType, Schema} -import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.types._ object ArrowUtils { From d8bae0b0e15c26244454e68fc571d19ee6a1a1b7 Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Tue, 10 Oct 2017 16:56:32 -0700 Subject: [PATCH 20/34] added Python timezone converions for working with Pandas --- python/pyspark/serializers.py | 7 +++++-- python/pyspark/sql/dataframe.py | 4 +++- python/pyspark/sql/tests.py | 6 ------ python/pyspark/sql/types.py | 34 +++++++++++++++++++++++++++++++++ 4 files changed, 42 insertions(+), 9 deletions(-) diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index ad18bd0c81ea..f38539853e8f 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -213,6 +213,7 @@ def __repr__(self): def _create_batch(series): + from pyspark.sql.types import _check_convert_series_timestamps import pyarrow as pa # Make input conform to [(series1, type1), (series2, type2), ...] if not isinstance(series, (list, tuple)) or \ @@ -228,7 +229,8 @@ def cast_series(s, t): else: return s.fillna(0).astype(t.to_pandas_dtype(), copy=False) - arrs = [pa.Array.from_pandas(cast_series(s, t), mask=s.isnull(), type=t) for s, t in series] + arrs = [pa.Array.from_pandas(_check_convert_series_timestamps(cast_series(s, t)), + mask=s.isnull(), type=t) for s, t in series] return pa.RecordBatch.from_arrays(arrs, ["_%d" % i for i in xrange(len(arrs))]) @@ -259,11 +261,12 @@ def load_stream(self, stream): """ Deserialize ArrowRecordBatches to an Arrow table and return as a list of pandas.Series. 
""" + from pyspark.sql.types import _check_localize_series_timestamps import pyarrow as pa reader = pa.open_stream(stream) for batch in reader: table = pa.Table.from_batches([batch]) - yield [c.to_pandas() for c in table.itercolumns()] + yield [_check_localize_series_timestamps(c.to_pandas()) for c in table.itercolumns()] def __repr__(self): return "ArrowStreamPandasSerializer" diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index c53fd2a37e4c..0b319cfc17c9 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -1880,11 +1880,13 @@ def toPandas(self): import pandas as pd if self.sql_ctx.getConf("spark.sql.execution.arrow.enabled", "false").lower() == "true": try: + from pyspark.sql.types import _check_localize_dataframe_timestamps import pyarrow tables = self._collectAsArrow() if tables: table = pyarrow.concat_tables(tables) - return table.to_pandas() + df = table.to_pandas() + return _check_localize_dataframe_timestamps(df) else: return pd.DataFrame.from_records([], columns=self.columns) except ImportError as e: diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 21bf2e299403..df81f1f1fe79 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -3141,9 +3141,6 @@ def test_toPandas_arrow_toggle(self): pdf = df.toPandas() self.spark.conf.set("spark.sql.execution.arrow.enabled", "true") pdf_arrow = df.toPandas() - # need to remove timezone for comparison - pdf_arrow["7_timestamp_t"] = \ - pdf_arrow["7_timestamp_t"].apply(lambda ts: ts.tz_localize(None)) self.assertFramesEqual(pdf_arrow, pdf) def test_pandas_round_trip(self): @@ -3158,9 +3155,6 @@ def test_pandas_round_trip(self): pdf = pd.DataFrame(data=data_dict) df = self.spark.createDataFrame(self.data, schema=self.schema) pdf_arrow = df.toPandas() - # need to remove timezone for comparison - pdf_arrow["7_timestamp_t"] = \ - pdf_arrow["7_timestamp_t"].apply(lambda ts: ts.tz_localize(None)) self.assertFramesEqual(pdf_arrow, pdf) def test_filtered_frame(self): diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index ebdc11c3b744..31b1dc1e04be 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -1624,6 +1624,40 @@ def toArrowType(dt): return arrow_type +def _localize_series_timestamps(s): + """ Convert a tz-aware timestamp to local tz-naive + """ + return s.dt.tz_localize(None) + + +def _check_localize_series_timestamps(s): + from pandas.types.common import is_datetime64tz_dtype + # TODO: handle nested timestamps? + return _localize_series_timestamps(s) if is_datetime64tz_dtype(s.dtype) else s + + +def _check_localize_dataframe_timestamps(df): + from pandas.types.common import is_datetime64tz_dtype + for column, series in df.iteritems(): + # TODO: handle nested timestamps? + if is_datetime64tz_dtype(series.dtype): + df[column] = _localize_series_timestamps(series) + return df + + +def _convert_series_timestamps(s): + """ Convert a tz-naive timestamp in local tz to UTC normalized + """ + # TODO: this should be system local tz or SESSION_LOCAL_TIMEZONE? + return s.dt.tz_convert("UTC") + + +def _check_convert_series_timestamps(s): + from pandas.types.common import is_datetime64_dtype + # TODO: handle nested timestamps? 
+ return _convert_series_timestamps(s) if is_datetime64_dtype(s.dtype) else s + + def _test(): import doctest from pyspark.context import SparkContext From c4fd5ae1e61d118877151d713e8ea0e53a7afaf4 Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Tue, 10 Oct 2017 17:17:37 -0700 Subject: [PATCH 21/34] fix compilation --- .../spark/sql/execution/python/FlatMapGroupsInPandasExec.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala index b996b5bb38ba..78073ed8d71c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala @@ -94,7 +94,7 @@ case class FlatMapGroupsInPandasExec( val columnarBatchIter = new ArrowPythonRunner( chainedFunc, bufferSize, reuseWorker, - PythonEvalType.SQL_PANDAS_UDF, argOffsets, schema) + PythonEvalType.SQL_PANDAS_UDF, argOffsets, schema, Option(conf.sessionLocalTimeZone)) .compute(grouped, context.partitionId(), context) columnarBatchIter.flatMap(_.rowIterator.asScala).map(UnsafeProjection.create(output, output)) From d1617fde80697bc3423d661c23997300cbb896ce Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Wed, 11 Oct 2017 09:33:55 -0700 Subject: [PATCH 22/34] fixed test comp --- .../spark/sql/execution/vectorized/ColumnarBatchSuite.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala index 0b179aa97c47..7e32ae7667f2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala @@ -1249,11 +1249,11 @@ class ColumnarBatchSuite extends SparkFunSuite { test("create columnar batch from Arrow column vectors") { val allocator = ArrowUtils.rootAllocator.newChildAllocator("int", 0, Long.MaxValue) - val vector1 = ArrowUtils.toArrowField("int1", IntegerType, nullable = true) + val vector1 = ArrowUtils.toArrowField("int1", IntegerType, nullable = true, None) .createVector(allocator).asInstanceOf[NullableIntVector] vector1.allocateNew() val mutator1 = vector1.getMutator() - val vector2 = ArrowUtils.toArrowField("int2", IntegerType, nullable = true) + val vector2 = ArrowUtils.toArrowField("int2", IntegerType, nullable = true, None) .createVector(allocator).asInstanceOf[NullableIntVector] vector2.allocateNew() val mutator2 = vector2.getMutator() From d7d9b477894d51ef937a84d54e01a41d4e048ba8 Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Wed, 11 Oct 2017 09:45:00 -0700 Subject: [PATCH 23/34] add conversion to Python system local timezone before localize --- python/pyspark/sql/types.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index abd16d1d0932..0612c54dec57 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -1627,7 +1627,7 @@ def to_arrow_type(dt): def _localize_series_timestamps(s): """ Convert a tz-aware timestamp to local tz-naive """ - return s.dt.tz_localize(None) + return s.dt.tz_convert('tzlocal()').dt.tz_localize(None) def _check_localize_series_timestamps(s): From 
efe3e27a1f374e4482cffe2ce3877aceffc5eaad Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Wed, 11 Oct 2017 16:45:46 -0700 Subject: [PATCH 24/34] timestamps with Arrow almost working for pandas_udfs --- python/pyspark/serializers.py | 6 +++--- python/pyspark/sql/tests.py | 37 +++++++++++++++++++++++++++++++++++ python/pyspark/sql/types.py | 28 ++++++++++++++------------ 3 files changed, 55 insertions(+), 16 deletions(-) diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index f38539853e8f..a8ce0eb1f9a3 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -213,7 +213,7 @@ def __repr__(self): def _create_batch(series): - from pyspark.sql.types import _check_convert_series_timestamps + from pyspark.sql.types import _check_utc_normalize_series_timestamps import pyarrow as pa # Make input conform to [(series1, type1), (series2, type2), ...] if not isinstance(series, (list, tuple)) or \ @@ -224,12 +224,12 @@ def _create_batch(series): # If a nullable integer series has been promoted to floating point with NaNs, need to cast # NOTE: this is not necessary with Arrow >= 0.7 def cast_series(s, t): - if t is None or s.dtype == t.to_pandas_dtype(): + if t is None or s.dtype == t.to_pandas_dtype() or type(t) == pa.TimestampType: return s else: return s.fillna(0).astype(t.to_pandas_dtype(), copy=False) - arrs = [pa.Array.from_pandas(_check_convert_series_timestamps(cast_series(s, t)), + arrs = [pa.Array.from_pandas(_check_utc_normalize_series_timestamps(cast_series(s, t)), mask=s.isnull(), type=t) for s, t in series] return pa.RecordBatch.from_arrays(arrs, ["_%d" % i for i in xrange(len(arrs))]) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index f95b2a795836..47391154e8de 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -3400,6 +3400,43 @@ def test_vectorized_udf_varargs(self): res = df.select(f(col('id'))) self.assertEquals(df.collect(), res.collect()) + def test_vectorized_udf_timestamps(self): + from pyspark.sql.functions import pandas_udf, col + from datetime import date, datetime + schema = StructType([ + StructField("idx", LongType(), True), + StructField("date", DateType(), True), + StructField("timestamp", TimestampType(), True)]) + # TODO Fails with time before epoch: (0, date(1969, 1, 1), datetime(1969, 1, 1, 1, 1, 1)) + data = [(0, date(1985, 1, 1), datetime(1985, 1, 1, 1, 1, 1)), + (1, date(2012, 2, 2), datetime(2012, 2, 2, 2, 2, 2)), + (2, date(2100, 3, 3), datetime(2100, 3, 3, 3, 3, 3)), + (3, date(2104, 4, 4), datetime(2104, 4, 4, 4, 4, 4))] + + df = self.spark.createDataFrame(data, schema=schema) + + # Check that a timestamp passed through a pandas_udf will not be altered by timezone calc + identity = pandas_udf(lambda t: t, returnType=TimestampType()) + df = df.withColumn("timestamp_copy", identity(col("timestamp"))) + + @pandas_udf(returnType=BooleanType()) + def check_data(idx, date, timestamp, timestamp_copy): + is_equal = timestamp == timestamp_copy + if is_equal.all(): + for i in xrange(len(is_equal)): + # TODO Fails with tz offset: date[i].date() == data[idx[i]][1] and + is_equal[i] = timestamp[i].to_pydatetime() == data[idx[i]][2] + return is_equal + + result = df.withColumn("is_equal", check_data(col("idx"), col("date"), col("timestamp"), + col("timestamp_copy"))).collect() + # Check that collection values are correct + self.assertEquals(len(data), len(result)) + for i in range(len(result)): + self.assertEquals(data[i][1], result[i][1]) # "date" col + 
self.assertEquals(data[i][2], result[i][2]) # "timestamp" col + self.assertTrue(result[i][4]) # "is_equal" data in udf was as expected + @unittest.skipIf(not _have_pandas or not _have_arrow, "Pandas or Arrow not installed") class GroupbyApplyTests(ReusedPySparkTestCase): diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index 0612c54dec57..b392b0f42d8b 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -1619,21 +1619,23 @@ def to_arrow_type(dt): arrow_type = pa.decimal(dt.precision, dt.scale) elif type(dt) == StringType: arrow_type = pa.string() + elif type(dt) == DateType: + arrow_type = pa.date32() + elif type(dt) == TimestampType: + arrow_type = pa.timestamp('us', tz='UTC') else: raise TypeError("Unsupported type in conversion to Arrow: " + str(dt)) return arrow_type -def _localize_series_timestamps(s): - """ Convert a tz-aware timestamp to local tz-naive - """ - return s.dt.tz_convert('tzlocal()').dt.tz_localize(None) - - def _check_localize_series_timestamps(s): - from pandas.types.common import is_datetime64tz_dtype + from pandas.types.common import is_datetime64_dtype # TODO: handle nested timestamps? - return _localize_series_timestamps(s) if is_datetime64tz_dtype(s.dtype) else s + if is_datetime64_dtype(s.dtype): + # TODO: pyarrow.Column.to_pandas keeps data in UTC but removes timezone + return s.dt.tz_localize('UTC').dt.tz_convert('tzlocal()').dt.tz_localize(None) + else: + return s def _check_localize_dataframe_timestamps(df): @@ -1641,21 +1643,21 @@ def _check_localize_dataframe_timestamps(df): for column, series in df.iteritems(): # TODO: handle nested timestamps? if is_datetime64tz_dtype(series.dtype): - df[column] = _localize_series_timestamps(series) + df[column] = series.dt.tz_convert('tzlocal()').dt.tz_localize(None) return df -def _convert_series_timestamps(s): +def _utc_normalize_series_timestamps(s): """ Convert a tz-naive timestamp in local tz to UTC normalized """ # TODO: this should be system local tz or SESSION_LOCAL_TIMEZONE? - return s.dt.tz_convert("UTC") + return s.dt.tz_localize('tzlocal()').dt.tz_convert('UTC').values.astype('datetime64[us]') -def _check_convert_series_timestamps(s): +def _check_utc_normalize_series_timestamps(s): from pandas.types.common import is_datetime64_dtype # TODO: handle nested timestamps? - return _convert_series_timestamps(s) if is_datetime64_dtype(s.dtype) else s + return _utc_normalize_series_timestamps(s) if is_datetime64_dtype(s.dtype) else s def _test(): From 989451915f9fe10776162bf19cdc22f4e6e749f9 Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Mon, 16 Oct 2017 17:04:10 -0700 Subject: [PATCH 25/34] added workaround for Series to_pandas with timestamps, store os.environ TZ during test --- python/pyspark/serializers.py | 11 ++++++----- python/pyspark/sql/dataframe.py | 4 ++-- python/pyspark/sql/tests.py | 10 ++++++---- python/pyspark/sql/types.py | 29 ++++++++++------------------- 4 files changed, 24 insertions(+), 30 deletions(-) diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index a8ce0eb1f9a3..68f24d05debe 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -213,7 +213,7 @@ def __repr__(self): def _create_batch(series): - from pyspark.sql.types import _check_utc_normalize_series_timestamps + from pyspark.sql.types import _check_series_convert_timestamps_internal import pyarrow as pa # Make input conform to [(series1, type1), (series2, type2), ...] 
if not isinstance(series, (list, tuple)) or \ @@ -229,7 +229,7 @@ def cast_series(s, t): else: return s.fillna(0).astype(t.to_pandas_dtype(), copy=False) - arrs = [pa.Array.from_pandas(_check_utc_normalize_series_timestamps(cast_series(s, t)), + arrs = [pa.Array.from_pandas(_check_series_convert_timestamps_internal(cast_series(s, t)), mask=s.isnull(), type=t) for s, t in series] return pa.RecordBatch.from_arrays(arrs, ["_%d" % i for i in xrange(len(arrs))]) @@ -261,12 +261,13 @@ def load_stream(self, stream): """ Deserialize ArrowRecordBatches to an Arrow table and return as a list of pandas.Series. """ - from pyspark.sql.types import _check_localize_series_timestamps + from pyspark.sql.types import _check_dataframe_localize_timestamps import pyarrow as pa reader = pa.open_stream(stream) for batch in reader: - table = pa.Table.from_batches([batch]) - yield [_check_localize_series_timestamps(c.to_pandas()) for c in table.itercolumns()] + # NOTE: changed from pa.Columns.to_pandas, timezone issue in conversion fixed in 0.7.1 + pdf = _check_dataframe_localize_timestamps(batch.to_pandas()) + yield [c for _, c in pdf.iteritems()] def __repr__(self): return "ArrowStreamPandasSerializer" diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 8e56c1ae89af..0a3888056317 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -1880,13 +1880,13 @@ def toPandas(self): import pandas as pd if self.sql_ctx.getConf("spark.sql.execution.arrow.enabled", "false").lower() == "true": try: - from pyspark.sql.types import _check_localize_dataframe_timestamps + from pyspark.sql.types import _check_dataframe_localize_timestamps import pyarrow tables = self._collectAsArrow() if tables: table = pyarrow.concat_tables(tables) df = table.to_pandas() - return _check_localize_dataframe_timestamps(df) + return _check_dataframe_localize_timestamps(df) else: return pd.DataFrame.from_records([], columns=self.columns) except ImportError as e: diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 47391154e8de..7e147ca092e7 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -3090,6 +3090,7 @@ def setUpClass(cls): ReusedPySparkTestCase.setUpClass() # Synchronize default timezone between Python and Java + cls.tz_prev = os.environ.get("TZ", None) # save current tz if set tz = "America/Los_Angeles" os.environ["TZ"] = tz time.tzset() @@ -3112,6 +3113,8 @@ def setUpClass(cls): @classmethod def tearDownClass(cls): del os.environ["TZ"] + if cls.tz_prev is not None: + os.environ["TZ"] = cls.tz_prev time.tzset() ReusedPySparkTestCase.tearDownClass() cls.spark.stop() @@ -3407,8 +3410,7 @@ def test_vectorized_udf_timestamps(self): StructField("idx", LongType(), True), StructField("date", DateType(), True), StructField("timestamp", TimestampType(), True)]) - # TODO Fails with time before epoch: (0, date(1969, 1, 1), datetime(1969, 1, 1, 1, 1, 1)) - data = [(0, date(1985, 1, 1), datetime(1985, 1, 1, 1, 1, 1)), + data = [(0, date(1969, 1, 1), datetime(1969, 1, 1, 1, 1, 1)), (1, date(2012, 2, 2), datetime(2012, 2, 2, 2, 2, 2)), (2, date(2100, 3, 3), datetime(2100, 3, 3, 3, 3, 3)), (3, date(2104, 4, 4), datetime(2104, 4, 4, 4, 4, 4))] @@ -3424,8 +3426,8 @@ def check_data(idx, date, timestamp, timestamp_copy): is_equal = timestamp == timestamp_copy if is_equal.all(): for i in xrange(len(is_equal)): - # TODO Fails with tz offset: date[i].date() == data[idx[i]][1] and - is_equal[i] = timestamp[i].to_pydatetime() == data[idx[i]][2] + 
is_equal[i] = date[i].date() == data[idx[i]][1] \ + and timestamp[i].to_pydatetime() == data[idx[i]][2] return is_equal result = df.withColumn("is_equal", check_data(col("idx"), col("date"), col("timestamp"), diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index b392b0f42d8b..7581b6adc072 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -1622,23 +1622,16 @@ def to_arrow_type(dt): elif type(dt) == DateType: arrow_type = pa.date32() elif type(dt) == TimestampType: + # Timestamps should be in UTC, JVM Arrow timestamps require a timezone to be read arrow_type = pa.timestamp('us', tz='UTC') else: raise TypeError("Unsupported type in conversion to Arrow: " + str(dt)) return arrow_type -def _check_localize_series_timestamps(s): - from pandas.types.common import is_datetime64_dtype - # TODO: handle nested timestamps? - if is_datetime64_dtype(s.dtype): - # TODO: pyarrow.Column.to_pandas keeps data in UTC but removes timezone - return s.dt.tz_localize('UTC').dt.tz_convert('tzlocal()').dt.tz_localize(None) - else: - return s - - -def _check_localize_dataframe_timestamps(df): +def _check_dataframe_localize_timestamps(df): + """ Convert timezone aware timestamps to timezone-naive in local time + """ from pandas.types.common import is_datetime64tz_dtype for column, series in df.iteritems(): # TODO: handle nested timestamps? @@ -1647,17 +1640,15 @@ def _check_localize_dataframe_timestamps(df): return df -def _utc_normalize_series_timestamps(s): - """ Convert a tz-naive timestamp in local tz to UTC normalized +def _check_series_convert_timestamps_internal(s): + """ Convert a tz-naive timestamp in local tz to UTC normalized for Spark internal storage """ - # TODO: this should be system local tz or SESSION_LOCAL_TIMEZONE? - return s.dt.tz_localize('tzlocal()').dt.tz_convert('UTC').values.astype('datetime64[us]') - - -def _check_utc_normalize_series_timestamps(s): from pandas.types.common import is_datetime64_dtype # TODO: handle nested timestamps? 
- return _utc_normalize_series_timestamps(s) if is_datetime64_dtype(s.dtype) else s + if is_datetime64_dtype(s.dtype): + return s.dt.tz_localize('tzlocal()').dt.tz_convert('UTC').values.astype('datetime64[us]') + else: + return s def _test(): From a3ba4accc97bf6d9b9dee6aa6f310619756ab276 Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Mon, 16 Oct 2017 22:08:29 -0700 Subject: [PATCH 26/34] change use of xrange for py3 --- python/pyspark/sql/tests.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 7e147ca092e7..4a2878394472 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -3425,7 +3425,7 @@ def test_vectorized_udf_timestamps(self): def check_data(idx, date, timestamp, timestamp_copy): is_equal = timestamp == timestamp_copy if is_equal.all(): - for i in xrange(len(is_equal)): + for i in range(len(is_equal)): is_equal[i] = date[i].date() == data[idx[i]][1] \ and timestamp[i].to_pydatetime() == data[idx[i]][2] return is_equal From 7266304fc68f3e5ca26a4dda543f18837b013e29 Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Tue, 17 Oct 2017 13:29:20 -0700 Subject: [PATCH 27/34] remove check for valid timezone in vector for ArrowWriter --- .../org/apache/spark/sql/execution/arrow/ArrowWriter.scala | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala index d4629a92c944..43cb1760c84d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala @@ -56,11 +56,7 @@ object ArrowWriter { case (StringType, vector: NullableVarCharVector) => new StringWriter(vector) case (BinaryType, vector: NullableVarBinaryVector) => new BinaryWriter(vector) case (DateType, vector: NullableDateDayVector) => new DateWriter(vector) - case (TimestampType, vector: NullableTimeStampMicroTZVector) - // TODO: Should be able to access timezone from vector - if field.getType.isInstanceOf[ArrowType.Timestamp] && - field.getType.asInstanceOf[ArrowType.Timestamp].getTimezone != null => - new TimestampWriter(vector) + case (TimestampType, vector: NullableTimeStampMicroTZVector) => new TimestampWriter(vector) case (ArrayType(_, _), vector: ListVector) => val elementVector = createFieldWriter(vector.getDataVector()) new ArrayWriter(vector, elementVector) From e428cbe2f8e626510e6660b38faff7f03b229e92 Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Tue, 17 Oct 2017 15:21:36 -0700 Subject: [PATCH 28/34] added note for 'us' conversion --- python/pyspark/sql/types.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index 7581b6adc072..702b135fd8d1 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -1646,6 +1646,7 @@ def _check_series_convert_timestamps_internal(s): from pandas.types.common import is_datetime64_dtype # TODO: handle nested timestamps? 
if is_datetime64_dtype(s.dtype): + # NOTE: convert to 'us' with astype here, unit is ignored in `from_pandas` see ARROW-1680 return s.dt.tz_localize('tzlocal()').dt.tz_convert('UTC').values.astype('datetime64[us]') else: return s From cade921bad78d45fc9d380363a88f94c4f2cbb79 Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Thu, 19 Oct 2017 11:35:13 -0700 Subject: [PATCH 29/34] changed python api for is_datetime64 --- python/pyspark/sql/types.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index 702b135fd8d1..a9e9aa289b8f 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -1632,7 +1632,7 @@ def to_arrow_type(dt): def _check_dataframe_localize_timestamps(df): """ Convert timezone aware timestamps to timezone-naive in local time """ - from pandas.types.common import is_datetime64tz_dtype + from pandas.api.types import is_datetime64tz_dtype for column, series in df.iteritems(): # TODO: handle nested timestamps? if is_datetime64tz_dtype(series.dtype): @@ -1643,7 +1643,7 @@ def _check_dataframe_localize_timestamps(df): def _check_series_convert_timestamps_internal(s): """ Convert a tz-naive timestamp in local tz to UTC normalized for Spark internal storage """ - from pandas.types.common import is_datetime64_dtype + from pandas.api.types import is_datetime64_dtype # TODO: handle nested timestamps? if is_datetime64_dtype(s.dtype): # NOTE: convert to 'us' with astype here, unit is ignored in `from_pandas` see ARROW-1680 From f512deb97f458043a6825bac8a44c1392f6c910b Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Thu, 19 Oct 2017 12:05:25 -0700 Subject: [PATCH 30/34] remove Option for timezoneId --- .../scala/org/apache/spark/sql/Dataset.scala | 2 +- .../sql/execution/arrow/ArrowConverters.scala | 2 +- .../sql/execution/arrow/ArrowUtils.scala | 14 ++++++------ .../sql/execution/arrow/ArrowWriter.scala | 2 +- .../python/ArrowEvalPythonExec.scala | 2 +- .../execution/python/ArrowPythonRunner.scala | 2 +- .../python/FlatMapGroupsInPandasExec.scala | 2 +- .../arrow/ArrowConvertersSuite.scala | 8 +++---- .../sql/execution/arrow/ArrowUtilsSuite.scala | 4 ++-- .../execution/arrow/ArrowWriterSuite.scala | 16 +++++++------- .../vectorized/ArrowColumnVectorSuite.scala | 22 +++++++++---------- .../vectorized/ColumnarBatchSuite.scala | 4 ++-- 12 files changed, 40 insertions(+), 40 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 3d3d6325fa82..12f09ba3d8c5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -3143,7 +3143,7 @@ class Dataset[T] private[sql]( private[sql] def toArrowPayload: RDD[ArrowPayload] = { val schemaCaptured = this.schema val maxRecordsPerBatch = sparkSession.sessionState.conf.arrowMaxRecordsPerBatch - val timeZoneId = Option(sparkSession.sessionState.conf.sessionLocalTimeZone) + val timeZoneId = sparkSession.sessionState.conf.sessionLocalTimeZone queryExecution.toRdd.mapPartitionsInternal { iter => val context = TaskContext.get() ArrowConverters.toPayloadIterator( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala index 39af2769f828..05ea1517fcac 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala @@ -74,7 +74,7 @@ private[sql] object ArrowConverters { rowIter: Iterator[InternalRow], schema: StructType, maxRecordsPerBatch: Int, - timeZoneId: Option[String], + timeZoneId: String, context: TaskContext): Iterator[ArrowPayload] = { val arrowSchema = ArrowUtils.toArrowSchema(schema, timeZoneId) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowUtils.scala index f7e6182a05d8..6ad11bda84bf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowUtils.scala @@ -32,7 +32,7 @@ object ArrowUtils { // todo: support more types. /** Maps data type from Spark to Arrow. NOTE: timeZoneId required for TimestampTypes */ - def toArrowType(dt: DataType, timeZoneId: Option[String]): ArrowType = dt match { + def toArrowType(dt: DataType, timeZoneId: String): ArrowType = dt match { case BooleanType => ArrowType.Bool.INSTANCE case ByteType => new ArrowType.Int(8, true) case ShortType => new ArrowType.Int(8 * 2, true) @@ -45,10 +45,10 @@ object ArrowUtils { case DecimalType.Fixed(precision, scale) => new ArrowType.Decimal(precision, scale) case DateType => new ArrowType.Date(DateUnit.DAY) case TimestampType => - timeZoneId match { - case Some(id) => new ArrowType.Timestamp(TimeUnit.MICROSECOND, id) - case None => - throw new UnsupportedOperationException("TimestampType must supply timeZoneId parameter") + if (timeZoneId == null) { + throw new UnsupportedOperationException("TimestampType must supply timeZoneId parameter") + } else { + new ArrowType.Timestamp(TimeUnit.MICROSECOND, timeZoneId) } case _ => throw new UnsupportedOperationException(s"Unsupported data type: ${dt.simpleString}") } @@ -73,7 +73,7 @@ object ArrowUtils { /** Maps field from Spark to Arrow. NOTE: timeZoneId required for TimestampType */ def toArrowField( - name: String, dt: DataType, nullable: Boolean, timeZoneId: Option[String]): Field = { + name: String, dt: DataType, nullable: Boolean, timeZoneId: String): Field = { dt match { case ArrayType(elementType, containsNull) => val fieldType = new FieldType(nullable, ArrowType.List.INSTANCE, null) @@ -108,7 +108,7 @@ object ArrowUtils { } /** Maps schema from Spark to Arrow. 
NOTE: timeZoneId required for TimestampType in StructType */ - def toArrowSchema(schema: StructType, timeZoneId: Option[String]): Schema = { + def toArrowSchema(schema: StructType, timeZoneId: String): Schema = { new Schema(schema.map { field => toArrowField(field.name, field.dataType, field.nullable, timeZoneId) }.asJava) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala index 43cb1760c84d..e4af4f65da12 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala @@ -29,7 +29,7 @@ import org.apache.spark.sql.types._ object ArrowWriter { - def create(schema: StructType, timeZoneId: Option[String]): ArrowWriter = { + def create(schema: StructType, timeZoneId: String): ArrowWriter = { val arrowSchema = ArrowUtils.toArrowSchema(schema, timeZoneId) val root = VectorSchemaRoot.create(arrowSchema, ArrowUtils.rootAllocator) create(root) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala index cd33d26ded35..0db463a5fbd8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala @@ -79,7 +79,7 @@ case class ArrowEvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute], chi val columnarBatchIter = new ArrowPythonRunner( funcs, bufferSize, reuseWorker, - PythonEvalType.SQL_PANDAS_UDF, argOffsets, schema, Option(conf.sessionLocalTimeZone)) + PythonEvalType.SQL_PANDAS_UDF, argOffsets, schema, conf.sessionLocalTimeZone) .compute(batchIter, context.partitionId(), context) new Iterator[InternalRow] { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala index ce4d1828917d..94c05b9b5e49 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala @@ -44,7 +44,7 @@ class ArrowPythonRunner( evalType: Int, argOffsets: Array[Array[Int]], schema: StructType, - timeZoneId: Option[String]) + timeZoneId: String) extends BasePythonRunner[Iterator[InternalRow], ColumnarBatch]( funcs, bufferSize, reuseWorker, evalType, argOffsets) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala index 78073ed8d71c..fc19ae6b56c2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala @@ -94,7 +94,7 @@ case class FlatMapGroupsInPandasExec( val columnarBatchIter = new ArrowPythonRunner( chainedFunc, bufferSize, reuseWorker, - PythonEvalType.SQL_PANDAS_UDF, argOffsets, schema, Option(conf.sessionLocalTimeZone)) + PythonEvalType.SQL_PANDAS_UDF, argOffsets, schema, conf.sessionLocalTimeZone) .compute(grouped, context.partitionId(), context) columnarBatchIter.flatMap(_.rowIterator.asScala).map(UnsafeProjection.create(output, output)) diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala index 54aa7997d40d..ba2903babbba 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala @@ -888,7 +888,7 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { val df = data.toDF("timestamp") - collectAndValidate(df, json, "timestampData.json", Option("America/Los_Angeles")) + collectAndValidate(df, json, "timestampData.json", "America/Los_Angeles") } } @@ -1728,7 +1728,7 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { val schema = StructType(Seq(StructField("int", IntegerType, nullable = true))) val ctx = TaskContext.empty() - val payloadIter = ArrowConverters.toPayloadIterator(inputRows.toIterator, schema, 0, None, ctx) + val payloadIter = ArrowConverters.toPayloadIterator(inputRows.toIterator, schema, 0, null, ctx) val outputRowIter = ArrowConverters.fromPayloadIterator(payloadIter, ctx) assert(schema.equals(outputRowIter.schema)) @@ -1748,7 +1748,7 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { /** Test that a converted DataFrame to Arrow record batch equals batch read from JSON file */ private def collectAndValidate( - df: DataFrame, json: String, file: String, timeZoneId: Option[String] = None): Unit = { + df: DataFrame, json: String, file: String, timeZoneId: String = null): Unit = { // NOTE: coalesce to single partition because can only load 1 batch in validator val arrowPayload = df.coalesce(1).toArrowPayload.collect().head val tempFile = new File(tempDataPath, file) @@ -1760,7 +1760,7 @@ class ArrowConvertersSuite extends SharedSQLContext with BeforeAndAfterAll { sparkSchema: StructType, arrowPayload: ArrowPayload, jsonFile: File, - timeZoneId: Option[String] = None): Unit = { + timeZoneId: String = null): Unit = { val allocator = new RootAllocator(Long.MaxValue) val jsonReader = new JsonFileReader(jsonFile, allocator) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowUtilsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowUtilsSuite.scala index 21e69d6a3359..d801f62b6232 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowUtilsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowUtilsSuite.scala @@ -28,7 +28,7 @@ class ArrowUtilsSuite extends SparkFunSuite { def roundtrip(dt: DataType): Unit = { dt match { case schema: StructType => - assert(ArrowUtils.fromArrowSchema(ArrowUtils.toArrowSchema(schema, None)) === schema) + assert(ArrowUtils.fromArrowSchema(ArrowUtils.toArrowSchema(schema, null)) === schema) case _ => roundtrip(new StructType().add("value", dt)) } @@ -56,7 +56,7 @@ class ArrowUtilsSuite extends SparkFunSuite { def roundtripWithTz(timeZoneId: String): Unit = { val schema = new StructType().add("value", TimestampType) - val arrowSchema = ArrowUtils.toArrowSchema(schema, Option(timeZoneId)) + val arrowSchema = ArrowUtils.toArrowSchema(schema, timeZoneId) val fieldType = arrowSchema.findField("value").getType.asInstanceOf[ArrowType.Timestamp] assert(fieldType.getTimezone() === timeZoneId) assert(ArrowUtils.fromArrowSchema(arrowSchema) === schema) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowWriterSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowWriterSuite.scala index 8abecd4a36e3..a71e30aa3ca9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowWriterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowWriterSuite.scala @@ -27,7 +27,7 @@ import org.apache.spark.unsafe.types.UTF8String class ArrowWriterSuite extends SparkFunSuite { test("simple") { - def check(dt: DataType, data: Seq[Any], timeZoneId: Option[String] = None): Unit = { + def check(dt: DataType, data: Seq[Any], timeZoneId: String = null): Unit = { val schema = new StructType().add("value", dt, nullable = true) val writer = ArrowWriter.create(schema, timeZoneId) assert(writer.schema === schema) @@ -69,11 +69,11 @@ class ArrowWriterSuite extends SparkFunSuite { check(StringType, Seq("a", "b", null, "d").map(UTF8String.fromString)) check(BinaryType, Seq("a".getBytes(), "b".getBytes(), null, "d".getBytes())) check(DateType, Seq(0, 1, 2, null, 4)) - check(TimestampType, Seq(0L, 3.6e9.toLong, null, 8.64e10.toLong), Some("America/Los_Angeles")) + check(TimestampType, Seq(0L, 3.6e9.toLong, null, 8.64e10.toLong), "America/Los_Angeles") } test("get multiple") { - def check(dt: DataType, data: Seq[Any], timeZoneId: Option[String] = None): Unit = { + def check(dt: DataType, data: Seq[Any], timeZoneId: String = null): Unit = { val schema = new StructType().add("value", dt, nullable = false) val writer = ArrowWriter.create(schema, timeZoneId) assert(writer.schema === schema) @@ -107,13 +107,13 @@ class ArrowWriterSuite extends SparkFunSuite { check(FloatType, (0 until 10).map(_.toFloat)) check(DoubleType, (0 until 10).map(_.toDouble)) check(DateType, (0 until 10)) - check(TimestampType, (0 until 10).map(_ * 4.32e10.toLong), Some("America/Los_Angeles")) + check(TimestampType, (0 until 10).map(_ * 4.32e10.toLong), "America/Los_Angeles") } test("array") { val schema = new StructType() .add("arr", ArrayType(IntegerType, containsNull = true), nullable = true) - val writer = ArrowWriter.create(schema, None) + val writer = ArrowWriter.create(schema, null) assert(writer.schema === schema) writer.write(InternalRow(ArrayData.toArrayData(Array(1, 2, 3)))) @@ -152,7 +152,7 @@ class ArrowWriterSuite extends SparkFunSuite { test("nested array") { val schema = new StructType().add("nested", ArrayType(ArrayType(IntegerType))) - val writer = ArrowWriter.create(schema, None) + val writer = ArrowWriter.create(schema, null) assert(writer.schema === schema) writer.write(InternalRow(ArrayData.toArrayData(Array( @@ -203,7 +203,7 @@ class ArrowWriterSuite extends SparkFunSuite { test("struct") { val schema = new StructType() .add("struct", new StructType().add("i", IntegerType).add("str", StringType)) - val writer = ArrowWriter.create(schema, None) + val writer = ArrowWriter.create(schema, null) assert(writer.schema === schema) writer.write(InternalRow(InternalRow(1, UTF8String.fromString("str1")))) @@ -239,7 +239,7 @@ class ArrowWriterSuite extends SparkFunSuite { test("nested struct") { val schema = new StructType().add("struct", new StructType().add("nested", new StructType().add("i", IntegerType).add("str", StringType))) - val writer = ArrowWriter.create(schema, None) + val writer = ArrowWriter.create(schema, null) assert(writer.schema === schema) writer.write(InternalRow(InternalRow(InternalRow(1, UTF8String.fromString("str1"))))) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ArrowColumnVectorSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ArrowColumnVectorSuite.scala index 33953dc11852..068a17bf772e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ArrowColumnVectorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ArrowColumnVectorSuite.scala @@ -29,7 +29,7 @@ class ArrowColumnVectorSuite extends SparkFunSuite { test("boolean") { val allocator = ArrowUtils.rootAllocator.newChildAllocator("boolean", 0, Long.MaxValue) - val vector = ArrowUtils.toArrowField("boolean", BooleanType, nullable = true, None) + val vector = ArrowUtils.toArrowField("boolean", BooleanType, nullable = true, null) .createVector(allocator).asInstanceOf[NullableBitVector] vector.allocateNew() val mutator = vector.getMutator() @@ -58,7 +58,7 @@ class ArrowColumnVectorSuite extends SparkFunSuite { test("byte") { val allocator = ArrowUtils.rootAllocator.newChildAllocator("byte", 0, Long.MaxValue) - val vector = ArrowUtils.toArrowField("byte", ByteType, nullable = true, None) + val vector = ArrowUtils.toArrowField("byte", ByteType, nullable = true, null) .createVector(allocator).asInstanceOf[NullableTinyIntVector] vector.allocateNew() val mutator = vector.getMutator() @@ -87,7 +87,7 @@ class ArrowColumnVectorSuite extends SparkFunSuite { test("short") { val allocator = ArrowUtils.rootAllocator.newChildAllocator("short", 0, Long.MaxValue) - val vector = ArrowUtils.toArrowField("short", ShortType, nullable = true, None) + val vector = ArrowUtils.toArrowField("short", ShortType, nullable = true, null) .createVector(allocator).asInstanceOf[NullableSmallIntVector] vector.allocateNew() val mutator = vector.getMutator() @@ -116,7 +116,7 @@ class ArrowColumnVectorSuite extends SparkFunSuite { test("int") { val allocator = ArrowUtils.rootAllocator.newChildAllocator("int", 0, Long.MaxValue) - val vector = ArrowUtils.toArrowField("int", IntegerType, nullable = true, None) + val vector = ArrowUtils.toArrowField("int", IntegerType, nullable = true, null) .createVector(allocator).asInstanceOf[NullableIntVector] vector.allocateNew() val mutator = vector.getMutator() @@ -145,7 +145,7 @@ class ArrowColumnVectorSuite extends SparkFunSuite { test("long") { val allocator = ArrowUtils.rootAllocator.newChildAllocator("long", 0, Long.MaxValue) - val vector = ArrowUtils.toArrowField("long", LongType, nullable = true, None) + val vector = ArrowUtils.toArrowField("long", LongType, nullable = true, null) .createVector(allocator).asInstanceOf[NullableBigIntVector] vector.allocateNew() val mutator = vector.getMutator() @@ -174,7 +174,7 @@ class ArrowColumnVectorSuite extends SparkFunSuite { test("float") { val allocator = ArrowUtils.rootAllocator.newChildAllocator("float", 0, Long.MaxValue) - val vector = ArrowUtils.toArrowField("float", FloatType, nullable = true, None) + val vector = ArrowUtils.toArrowField("float", FloatType, nullable = true, null) .createVector(allocator).asInstanceOf[NullableFloat4Vector] vector.allocateNew() val mutator = vector.getMutator() @@ -203,7 +203,7 @@ class ArrowColumnVectorSuite extends SparkFunSuite { test("double") { val allocator = ArrowUtils.rootAllocator.newChildAllocator("double", 0, Long.MaxValue) - val vector = ArrowUtils.toArrowField("double", DoubleType, nullable = true, None) + val vector = ArrowUtils.toArrowField("double", DoubleType, nullable = true, null) .createVector(allocator).asInstanceOf[NullableFloat8Vector] vector.allocateNew() val mutator = vector.getMutator() @@ -232,7 +232,7 @@ class 
ArrowColumnVectorSuite extends SparkFunSuite { test("string") { val allocator = ArrowUtils.rootAllocator.newChildAllocator("string", 0, Long.MaxValue) - val vector = ArrowUtils.toArrowField("string", StringType, nullable = true, None) + val vector = ArrowUtils.toArrowField("string", StringType, nullable = true, null) .createVector(allocator).asInstanceOf[NullableVarCharVector] vector.allocateNew() val mutator = vector.getMutator() @@ -260,7 +260,7 @@ class ArrowColumnVectorSuite extends SparkFunSuite { test("binary") { val allocator = ArrowUtils.rootAllocator.newChildAllocator("binary", 0, Long.MaxValue) - val vector = ArrowUtils.toArrowField("binary", BinaryType, nullable = true, None) + val vector = ArrowUtils.toArrowField("binary", BinaryType, nullable = true, null) .createVector(allocator).asInstanceOf[NullableVarBinaryVector] vector.allocateNew() val mutator = vector.getMutator() @@ -288,7 +288,7 @@ class ArrowColumnVectorSuite extends SparkFunSuite { test("array") { val allocator = ArrowUtils.rootAllocator.newChildAllocator("array", 0, Long.MaxValue) - val vector = ArrowUtils.toArrowField("array", ArrayType(IntegerType), nullable = true, None) + val vector = ArrowUtils.toArrowField("array", ArrayType(IntegerType), nullable = true, null) .createVector(allocator).asInstanceOf[ListVector] vector.allocateNew() val mutator = vector.getMutator() @@ -345,7 +345,7 @@ class ArrowColumnVectorSuite extends SparkFunSuite { test("struct") { val allocator = ArrowUtils.rootAllocator.newChildAllocator("struct", 0, Long.MaxValue) val schema = new StructType().add("int", IntegerType).add("long", LongType) - val vector = ArrowUtils.toArrowField("struct", schema, nullable = true, None) + val vector = ArrowUtils.toArrowField("struct", schema, nullable = true, null) .createVector(allocator).asInstanceOf[NullableMapVector] vector.allocateNew() val mutator = vector.getMutator() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala index 7e32ae7667f2..4cfc776e51db 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala @@ -1249,11 +1249,11 @@ class ColumnarBatchSuite extends SparkFunSuite { test("create columnar batch from Arrow column vectors") { val allocator = ArrowUtils.rootAllocator.newChildAllocator("int", 0, Long.MaxValue) - val vector1 = ArrowUtils.toArrowField("int1", IntegerType, nullable = true, None) + val vector1 = ArrowUtils.toArrowField("int1", IntegerType, nullable = true, null) .createVector(allocator).asInstanceOf[NullableIntVector] vector1.allocateNew() val mutator1 = vector1.getMutator() - val vector2 = ArrowUtils.toArrowField("int2", IntegerType, nullable = true, None) + val vector2 = ArrowUtils.toArrowField("int2", IntegerType, nullable = true, null) .createVector(allocator).asInstanceOf[NullableIntVector] vector2.allocateNew() val mutator2 = vector2.getMutator() From 79bb93f36ad6a0096f59072c54097015e2099a73 Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Mon, 23 Oct 2017 15:02:17 -0700 Subject: [PATCH 31/34] added pandas_udf test for date --- python/pyspark/serializers.py | 13 +++++--- python/pyspark/sql/tests.py | 58 ++++++++++++++++++++--------------- python/pyspark/sql/types.py | 12 ++------ 3 files changed, 46 insertions(+), 37 deletions(-) diff --git a/python/pyspark/serializers.py 

From 79bb93f36ad6a0096f59072c54097015e2099a73 Mon Sep 17 00:00:00 2001
From: Bryan Cutler
Date: Mon, 23 Oct 2017 15:02:17 -0700
Subject: [PATCH 31/34] added pandas_udf test for date

---
 python/pyspark/serializers.py | 13 +++++---
 python/pyspark/sql/tests.py   | 58 ++++++++++++++++++++---------------
 python/pyspark/sql/types.py   | 12 ++------
 3 files changed, 46 insertions(+), 37 deletions(-)

diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py
index 8e314d798382..a1ad7415ac00 100644
--- a/python/pyspark/serializers.py
+++ b/python/pyspark/serializers.py
@@ -214,7 +214,7 @@ def __repr__(self):
 
 
 def _create_batch(series):
-    from pyspark.sql.types import _check_series_convert_timestamps_internal
+    from pyspark.sql.types import _series_convert_timestamps_internal
     import pyarrow as pa
     # Make input conform to [(series1, type1), (series2, type2), ...]
     if not isinstance(series, (list, tuple)) or \
@@ -225,13 +225,18 @@ def _create_batch(series):
     # If a nullable integer series has been promoted to floating point with NaNs, need to cast
     # NOTE: this is not necessary with Arrow >= 0.7
     def cast_series(s, t):
-        if t is None or s.dtype == t.to_pandas_dtype() or type(t) == pa.TimestampType:
+        if type(t) == pa.TimestampType:
+            # NOTE: convert to 'us' with astype here, unit ignored in `from_pandas` see ARROW-1680
+            return _series_convert_timestamps_internal(s).values.astype('datetime64[us]')
+        elif t == pa.date32():
+            # TODO: ValueError: Cannot cast DatetimeIndex to dtype datetime64[D]
+            return s.dt.values.astype('datetime64[D]')
+        elif t is None or s.dtype == t.to_pandas_dtype():
             return s
         else:
             return s.fillna(0).astype(t.to_pandas_dtype(), copy=False)
 
-    arrs = [pa.Array.from_pandas(_check_series_convert_timestamps_internal(cast_series(s, t)),
-                                 mask=s.isnull(), type=t) for s, t in series]
+    arrs = [pa.Array.from_pandas(cast_series(s, t), mask=s.isnull(), type=t) for s, t in series]
     return pa.RecordBatch.from_arrays(arrs, ["_%d" % i for i in xrange(len(arrs))])
 
 
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 19b195a73596..0b8517c619d9 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -3405,48 +3405,58 @@ def test_vectorized_udf_varargs(self):
 
     def test_vectorized_udf_unsupported_types(self):
         from pyspark.sql.functions import pandas_udf, col
-        schema = StructType([StructField("dt", DateType(), True)])
-        df = self.spark.createDataFrame([(datetime.date(1970, 1, 1),)], schema=schema)
-        f = pandas_udf(lambda x: x, DateType())
+        schema = StructType([StructField("dt", DecimalType(), True)])
+        df = self.spark.createDataFrame([(None,)], schema=schema)
+        f = pandas_udf(lambda x: x, DecimalType())
         with QuietTest(self.sc):
             with self.assertRaisesRegexp(Exception, 'Unsupported data type'):
                 df.select(f(col('dt'))).collect()
 
+    def test_vectorized_udf_null_date(self):
+        from pyspark.sql.functions import pandas_udf, col
+        from datetime import date
+        schema = StructType().add("date", DateType())
+        data = [(date(1969, 1, 1),),
+                (date(2012, 2, 2),),
+                (None,),
+                (date(2100, 4, 4),)]
+        df = self.spark.createDataFrame(data, schema=schema)
+        date_f = pandas_udf(lambda t: t, returnType=DateType())
+        res = df.select(date_f(col("date")))
+        self.assertEquals(df.collect(), res.collect())
+
     def test_vectorized_udf_timestamps(self):
         from pyspark.sql.functions import pandas_udf, col
-        from datetime import date, datetime
+        from datetime import datetime
         schema = StructType([
             StructField("idx", LongType(), True),
-            StructField("date", DateType(), True),
             StructField("timestamp", TimestampType(), True)])
-        data = [(0, date(1969, 1, 1), datetime(1969, 1, 1, 1, 1, 1)),
-                (1, date(2012, 2, 2), datetime(2012, 2, 2, 2, 2, 2)),
-                (2, date(2100, 3, 3), datetime(2100, 3, 3, 3, 3, 3)),
-                (3, date(2104, 4, 4), datetime(2104, 4, 4, 4, 4, 4))]
-
+        data = [(0, datetime(1969, 1, 1, 1, 1, 1)),
+                (1, datetime(2012, 2, 2, 2, 2, 2)),
+                (2, None),
+                (3, datetime(2100, 4, 4, 4, 4, 4))]
         df = self.spark.createDataFrame(data, schema=schema)
 
         # Check that a timestamp passed through a pandas_udf will not be altered by timezone calc
-        identity = pandas_udf(lambda t: t, returnType=TimestampType())
-        df = df.withColumn("timestamp_copy", identity(col("timestamp")))
+        f_timestamp_copy = pandas_udf(lambda t: t, returnType=TimestampType())
+        df = df.withColumn("timestamp_copy", f_timestamp_copy(col("timestamp")))
 
         @pandas_udf(returnType=BooleanType())
-        def check_data(idx, date, timestamp, timestamp_copy):
-            is_equal = timestamp == timestamp_copy
-            if is_equal.all():
-                for i in range(len(is_equal)):
-                    is_equal[i] = date[i].date() == data[idx[i]][1] \
-                        and timestamp[i].to_pydatetime() == data[idx[i]][2]
+        def check_data(idx, timestamp, timestamp_copy):
+            is_equal = timestamp.isnull()  # use this array to check values are equal
+            for i in range(len(idx)):
+                # Check that timestamps are as expected in the UDF
+                is_equal[i] = (is_equal[i] and data[idx[i]][1] is None) or \
+                    timestamp[i].to_pydatetime() == data[idx[i]][1]
             return is_equal
 
-        result = df.withColumn("is_equal", check_data(col("idx"), col("date"), col("timestamp"),
+        result = df.withColumn("is_equal", check_data(col("idx"), col("timestamp"),
                                                       col("timestamp_copy"))).collect()
         # Check that collection values are correct
         self.assertEquals(len(data), len(result))
         for i in range(len(result)):
-            self.assertEquals(data[i][1], result[i][1])  # "date" col
-            self.assertEquals(data[i][2], result[i][2])  # "timestamp" col
-            self.assertTrue(result[i][4])  # "is_equal" data in udf was as expected
+            self.assertEquals(data[i][1], result[i][1])  # "timestamp" col
+            self.assertTrue(result[i][3])  # "is_equal" data in udf was as expected
 
 
 @unittest.skipIf(not _have_pandas or not _have_arrow, "Pandas or Arrow not installed")
diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py
index a9e9aa289b8f..a70e7bbe6096 100644
--- a/python/pyspark/sql/types.py
+++ b/python/pyspark/sql/types.py
@@ -1634,22 +1634,16 @@ def _check_dataframe_localize_timestamps(df):
     """
     from pandas.api.types import is_datetime64tz_dtype
     for column, series in df.iteritems():
-        # TODO: handle nested timestamps?
+        # TODO: handle nested timestamps, such as ArrayType(TimestampType())?
         if is_datetime64tz_dtype(series.dtype):
             df[column] = series.dt.tz_convert('tzlocal()').dt.tz_localize(None)
     return df
 
 
-def _check_series_convert_timestamps_internal(s):
+def _series_convert_timestamps_internal(s):
     """ Convert a tz-naive timestamp in local tz to UTC normalized for Spark internal storage
     """
-    from pandas.api.types import is_datetime64_dtype
-    # TODO: handle nested timestamps?
-    if is_datetime64_dtype(s.dtype):
-        # NOTE: convert to 'us' with astype here, unit is ignored in `from_pandas` see ARROW-1680
-        return s.dt.tz_localize('tzlocal()').dt.tz_convert('UTC').values.astype('datetime64[us]')
-    else:
-        return s
+    return s.dt.tz_localize('tzlocal()').dt.tz_convert('UTC')
 
 
 def _test():
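
The renamed helper `_series_convert_timestamps_internal` assumes a tz-naive datetime64 series, and `cast_series` now routes timestamps through it before truncating to microseconds for Arrow. A minimal standalone sketch of that conversion path (illustration only, assuming pandas >= 0.19; the sample values are arbitrary):

    import pandas as pd

    s = pd.Series(pd.to_datetime(['2015-04-08 13:10:15', '2016-05-09 12:01:01']))
    # Localize the tz-naive values to the local zone, then normalize to UTC,
    # mirroring what _series_convert_timestamps_internal does above.
    utc = s.dt.tz_localize('tzlocal()').dt.tz_convert('UTC')
    # Truncate to microseconds; the unit passed to `from_pandas` is ignored (ARROW-1680).
    micros = utc.values.astype('datetime64[us]')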

From c5552070ee846b93e927264a52e13afed2f6664f Mon Sep 17 00:00:00 2001
From: Bryan Cutler
Date: Tue, 24 Oct 2017 12:06:50 -0700
Subject: [PATCH 32/34] added workaround for date casting, put back check for
 timestamp conversion, set timestamp cast flag for copy to false

---
 python/pyspark/serializers.py | 17 ++++++++++++-----
 python/pyspark/sql/types.py   |  9 +++++++--
 2 files changed, 19 insertions(+), 7 deletions(-)

diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py
index a1ad7415ac00..d93fc6ac7a6c 100644
--- a/python/pyspark/serializers.py
+++ b/python/pyspark/serializers.py
@@ -214,7 +214,7 @@ def __repr__(self):
 
 
 def _create_batch(series):
-    from pyspark.sql.types import _series_convert_timestamps_internal
+    from pyspark.sql.types import _check_series_convert_timestamps_internal
     import pyarrow as pa
     # Make input conform to [(series1, type1), (series2, type2), ...]
     if not isinstance(series, (list, tuple)) or \
@@ -227,16 +227,23 @@ def _create_batch(series):
     def cast_series(s, t):
         if type(t) == pa.TimestampType:
             # NOTE: convert to 'us' with astype here, unit ignored in `from_pandas` see ARROW-1680
-            return _series_convert_timestamps_internal(s).values.astype('datetime64[us]')
+            return _check_series_convert_timestamps_internal(s)\
+                .values.astype('datetime64[us]', copy=False)
         elif t == pa.date32():
-            # TODO: ValueError: Cannot cast DatetimeIndex to dtype datetime64[D]
-            return s.dt.values.astype('datetime64[D]')
+            # TODO: this converts the series to Python objects, possibly avoid with Arrow >= 0.8
+            return s.dt.date
         elif t is None or s.dtype == t.to_pandas_dtype():
             return s
         else:
             return s.fillna(0).astype(t.to_pandas_dtype(), copy=False)
 
-    arrs = [pa.Array.from_pandas(cast_series(s, t), mask=s.isnull(), type=t) for s, t in series]
+    # Some object types don't support masks in Arrow, see ARROW-1721
+    def create_array(s, t):
+        casted = cast_series(s, t)
+        mask = None if casted.dtype == 'object' else s.isnull()
+        return pa.Array.from_pandas(casted, mask=mask, type=t)
+
+    arrs = [create_array(s, t) for s, t in series]
     return pa.RecordBatch.from_arrays(arrs, ["_%d" % i for i in xrange(len(arrs))])
 
 
diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py
index a70e7bbe6096..3921b05d47a1 100644
--- a/python/pyspark/sql/types.py
+++ b/python/pyspark/sql/types.py
@@ -1640,10 +1640,15 @@ def _check_dataframe_localize_timestamps(df):
     return df
 
 
-def _series_convert_timestamps_internal(s):
+def _check_series_convert_timestamps_internal(s):
     """ Convert a tz-naive timestamp in local tz to UTC normalized for Spark internal storage
     """
-    return s.dt.tz_localize('tzlocal()').dt.tz_convert('UTC')
+    from pandas.api.types import is_datetime64_dtype
+    # TODO: handle nested timestamps, such as ArrayType(TimestampType())?
+    if is_datetime64_dtype(s.dtype):
+        return s.dt.tz_localize('tzlocal()').dt.tz_convert('UTC')
+    else:
+        return s
 
 
 def _test():
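
The date workaround above returns `s.dt.date`, which yields a Series of Python `datetime.date` objects (object dtype), and `create_array` then skips the validity mask for object dtypes because of ARROW-1721. A standalone sketch of that path (illustration only, assuming a pyarrow build that accepts Python date objects for `date32`, as the patch relies on):

    import pandas as pd
    import pyarrow as pa

    s = pd.Series(pd.to_datetime(['2015-04-08', None]))
    casted = s.dt.date                          # object dtype: datetime.date values and NaT
    mask = None if casted.dtype == 'object' else s.isnull()   # ARROW-1721 workaround
    arr = pa.Array.from_pandas(casted, mask=mask, type=pa.date32())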

From 4d4089330d451bf6a145c28a6f34407ce3138b4d Mon Sep 17 00:00:00 2001
From: Bryan Cutler
Date: Wed, 25 Oct 2017 10:21:00 -0700
Subject: [PATCH 33/34] added fillna for null timestamp values

---
 python/pyspark/serializers.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py
index d93fc6ac7a6c..d7979f095da7 100644
--- a/python/pyspark/serializers.py
+++ b/python/pyspark/serializers.py
@@ -227,7 +227,7 @@ def _create_batch(series):
     def cast_series(s, t):
         if type(t) == pa.TimestampType:
             # NOTE: convert to 'us' with astype here, unit ignored in `from_pandas` see ARROW-1680
-            return _check_series_convert_timestamps_internal(s)\
+            return _check_series_convert_timestamps_internal(s.fillna(0))\
                 .values.astype('datetime64[us]', copy=False)
         elif t == pa.date32():
             # TODO: this converts the series to Python objects, possibly avoid with Arrow >= 0.8
             return s.dt.date
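
Filling with 0 replaces NaT with the epoch before the timestamp cast, so the astype never sees a missing value; the original null positions are still reported to Arrow through the mask built in `create_array` for non-object dtypes. A small standalone illustration (not part of the patch, assuming pandas >= 0.19):

    import pandas as pd

    ts = pd.Series(pd.to_datetime(['2012-02-02 02:02:02', None]))
    mask = ts.isnull()                 # remembers which entries were NaT
    filled = ts.fillna(0)              # NaT -> 1970-01-01 placeholder
    micros = filled.dt.tz_localize('tzlocal()').dt.tz_convert('UTC').values.astype('datetime64[us]')
    # Arrow marks the masked positions as null, so the placeholder never surfaces.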

From addd35f49d58227dd59b7d9d3595403cb992c1a8 Mon Sep 17 00:00:00 2001
From: Bryan Cutler
Date: Thu, 26 Oct 2017 10:11:26 -0700
Subject: [PATCH 34/34] added check for pandas_udf return is a timestamp with
 tz, added comments on conversion function input and output

---
 python/pyspark/sql/dataframe.py |  4 ++--
 python/pyspark/sql/tests.py     | 18 ++++++++++++++++++
 python/pyspark/sql/types.py     | 23 ++++++++++++++++-------
 3 files changed, 36 insertions(+), 9 deletions(-)

diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index 38fa80438256..5ad53cff3cf6 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -1885,8 +1885,8 @@ def toPandas(self):
                 tables = self._collectAsArrow()
                 if tables:
                     table = pyarrow.concat_tables(tables)
-                    df = table.to_pandas()
-                    return _check_dataframe_localize_timestamps(df)
+                    pdf = table.to_pandas()
+                    return _check_dataframe_localize_timestamps(pdf)
                 else:
                     return pd.DataFrame.from_records([], columns=self.columns)
             except ImportError as e:
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 0b8517c619d9..98afae662b42 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -3458,6 +3458,24 @@ def check_data(idx, timestamp, timestamp_copy):
             self.assertEquals(data[i][1], result[i][1])  # "timestamp" col
             self.assertTrue(result[i][3])  # "is_equal" data in udf was as expected
 
+    def test_vectorized_udf_return_timestamp_tz(self):
+        from pyspark.sql.functions import pandas_udf, col
+        import pandas as pd
+        df = self.spark.range(10)
+
+        @pandas_udf(returnType=TimestampType())
+        def gen_timestamps(id):
+            ts = [pd.Timestamp(i, unit='D', tz='America/Los_Angeles') for i in id]
+            return pd.Series(ts)
+
+        result = df.withColumn("ts", gen_timestamps(col("id"))).collect()
+        spark_ts_t = TimestampType()
+        for r in result:
+            i, ts = r
+            ts_tz = pd.Timestamp(i, unit='D', tz='America/Los_Angeles').to_pydatetime()
+            expected = spark_ts_t.fromInternal(spark_ts_t.toInternal(ts_tz))
+            self.assertEquals(expected, ts)
+
 
 @unittest.skipIf(not _have_pandas or not _have_arrow, "Pandas or Arrow not installed")
 class GroupbyApplyTests(ReusedPySparkTestCase):
diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py
index 3921b05d47a1..7dd8fa04160e 100644
--- a/python/pyspark/sql/types.py
+++ b/python/pyspark/sql/types.py
@@ -1629,24 +1629,33 @@ def to_arrow_type(dt):
     return arrow_type
 
 
-def _check_dataframe_localize_timestamps(df):
-    """ Convert timezone aware timestamps to timezone-naive in local time
+def _check_dataframe_localize_timestamps(pdf):
+    """
+    Convert timezone aware timestamps to timezone-naive in local time
+
+    :param pdf: pandas.DataFrame
+    :return pandas.DataFrame where any timezone aware columns have be converted to tz-naive
     """
     from pandas.api.types import is_datetime64tz_dtype
-    for column, series in df.iteritems():
+    for column, series in pdf.iteritems():
         # TODO: handle nested timestamps, such as ArrayType(TimestampType())?
         if is_datetime64tz_dtype(series.dtype):
-            df[column] = series.dt.tz_convert('tzlocal()').dt.tz_localize(None)
-    return df
+            pdf[column] = series.dt.tz_convert('tzlocal()').dt.tz_localize(None)
+    return pdf
 
 
 def _check_series_convert_timestamps_internal(s):
-    """ Convert a tz-naive timestamp in local tz to UTC normalized for Spark internal storage
     """
-    from pandas.api.types import is_datetime64_dtype
+    Convert a tz-naive timestamp in local tz to UTC normalized for Spark internal storage
+    :param s: a pandas.Series
+    :return pandas.Series where if it is a timestamp, has been UTC normalized without a time zone
+    """
+    from pandas.api.types import is_datetime64_dtype, is_datetime64tz_dtype
     # TODO: handle nested timestamps, such as ArrayType(TimestampType())?
     if is_datetime64_dtype(s.dtype):
         return s.dt.tz_localize('tzlocal()').dt.tz_convert('UTC')
+    elif is_datetime64tz_dtype(s.dtype):
+        return s.dt.tz_convert('UTC')
     else:
         return s
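
With the new `is_datetime64tz_dtype` branch, values coming back from a pandas_udf that are already tz-aware (as exercised by test_vectorized_udf_return_timestamp_tz) are normalized straight to UTC instead of being localized first. A standalone sketch of the two branches (illustration only, assuming pandas >= 0.19):

    import pandas as pd
    from pandas.api.types import is_datetime64_dtype, is_datetime64tz_dtype

    naive = pd.Series(pd.to_datetime(['2012-02-02 02:02:02']))
    aware = pd.Series([pd.Timestamp(0, unit='D', tz='America/Los_Angeles')])

    for s in (naive, aware):
        if is_datetime64_dtype(s.dtype):        # tz-naive: treat as local time, then UTC
            print(s.dt.tz_localize('tzlocal()').dt.tz_convert('UTC'))
        elif is_datetime64tz_dtype(s.dtype):    # tz-aware: convert directly to UTC
            print(s.dt.tz_convert('UTC'))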