From 34231e47511fbb249e3a36d0753683f147588edf Mon Sep 17 00:00:00 2001
From: Maxim Gekk
Date: Mon, 1 Apr 2019 23:13:17 +0200
Subject: [PATCH 1/6] Row getters for LocalDate and Instant

---
 .../src/main/scala/org/apache/spark/sql/Row.scala  | 14 ++++++++++++++
 1 file changed, 14 insertions(+)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala
index 4f5af9ac80b1..f13eddee77e4 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala
@@ -269,6 +269,13 @@ trait Row extends Serializable {
    */
   def getDate(i: Int): java.sql.Date = getAs[java.sql.Date](i)
 
+  /**
+   * Returns the value at position i of date type as java.time.LocalDate.
+   *
+   * @throws ClassCastException when data type does not match.
+   */
+  def getLocalDate(i: Int): java.time.LocalDate = getAs[java.time.LocalDate](i)
+
   /**
    * Returns the value at position i of date type as java.sql.Timestamp.
    *
@@ -276,6 +283,13 @@ trait Row extends Serializable {
    */
   def getTimestamp(i: Int): java.sql.Timestamp = getAs[java.sql.Timestamp](i)
 
+  /**
+   * Returns the value at position i of timestamp type as java.time.Instant.
+   *
+   * @throws ClassCastException when data type does not match.
+   */
+  def getInstant(i: Int): java.time.Instant = getAs[java.time.Instant](i)
+
   /**
    * Returns the value at position i of array type as a Scala Seq.
    *
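Both getters are thin wrappers over getAs, but they give callers typed accessors symmetric to getDate/getTimestamp. A minimal usage sketch, not part of the patch: the class name is hypothetical, and it assumes the spark.sql.datetime.java8API.enabled flag (surfaced later in this series as SQLConf.DATETIME_JAVA8API_ENABLED), since without it DATE/TIMESTAMP values still collect as java.sql.Date/java.sql.Timestamp and both getters throw ClassCastException:

    import java.time.Instant;
    import java.time.LocalDate;

    import org.apache.spark.sql.Row;
    import org.apache.spark.sql.SparkSession;

    public class RowGettersExample {
      public static void main(String[] args) {
        SparkSession spark = SparkSession.builder()
            .master("local[*]")
            .appName("row-getters-example")
            // Externalize DATE/TIMESTAMP as java.time.LocalDate/Instant.
            .config("spark.sql.datetime.java8API.enabled", "true")
            .getOrCreate();

        Row row = spark.sql(
            "SELECT CAST('2019-04-01' AS DATE) AS d, " +
            "CAST('2019-04-01 23:13:17' AS TIMESTAMP) AS t").head();

        LocalDate d = row.getLocalDate(0);  // instead of row.getAs(0) plus a cast
        Instant t = row.getInstant(1);
        System.out.println(d + " / " + t);

        spark.stop();
      }
    }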
From 506c9e7d30b45ab10744a15c11bae392cf17cd5f Mon Sep 17 00:00:00 2001
From: Maxim Gekk
Date: Mon, 1 Apr 2019 23:26:19 +0200
Subject: [PATCH 2/6] Test for LocalDate and Instant

---
 .../sql/JavaBeanDeserializationSuite.java          | 84 +++++++++++++++++++
 1 file changed, 84 insertions(+)

diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaBeanDeserializationSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaBeanDeserializationSuite.java
index c5f38676ad0a..b0834184c4cb 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaBeanDeserializationSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaBeanDeserializationSuite.java
@@ -18,6 +18,8 @@
 package test.org.apache.spark.sql;
 
 import java.io.Serializable;
+import java.time.Instant;
+import java.time.LocalDate;
 import java.util.*;
 
 import org.apache.spark.sql.*;
@@ -509,4 +511,86 @@ public void setId(Integer id) {
       this.id = id;
     }
   }
+
+  @Test
+  public void testSpark30() {
+    List<Row> inputRows = new ArrayList<>();
+    List<RecordSpark30> expectedRecords = new ArrayList<>();
+
+    for (long idx = 0 ; idx < 5 ; idx++) {
+      Row row = createRecordSpark30Row(idx);
+      inputRows.add(row);
+      expectedRecords.add(createRecordSpark30(row));
+    }
+
+    Encoder<RecordSpark30> encoder = Encoders.bean(RecordSpark30.class);
+
+    StructType schema = new StructType()
+      .add("localDateField", DataTypes.DateType)
+      .add("instantField", DataTypes.TimestampType);
+
+    Dataset<Row> dataFrame = spark.createDataFrame(inputRows, schema);
+    Dataset<RecordSpark30> dataset = dataFrame.as(encoder);
+
+    List<RecordSpark30> records = dataset.collectAsList();
+
+    Assert.assertEquals(expectedRecords, records);
+  }
+
+  public static final class RecordSpark30 {
+    private String localDateField;
+    private String instantField;
+
+    public RecordSpark30() { }
+
+    public String getLocalDateField() {
+      return localDateField;
+    }
+
+    public void setLocalDateField(String localDateField) {
+      this.localDateField = localDateField;
+    }
+
+    public String getInstantField() {
+      return instantField;
+    }
+
+    public void setInstantField(String instantField) {
+      this.instantField = instantField;
+    }
+
+    @Override
+    public boolean equals(Object o) {
+      if (this == o) return true;
+      if (o == null || getClass() != o.getClass()) return false;
+      RecordSpark30 that = (RecordSpark30) o;
+      return Objects.equals(localDateField, that.localDateField) &&
+        Objects.equals(instantField, that.instantField);
+    }
+
+    @Override
+    public int hashCode() {
+      return Objects.hash(localDateField, instantField);
+    }
+
+    @Override
+    public String toString() {
+      return com.google.common.base.Objects.toStringHelper(this)
+        .add("localDateField", localDateField)
+        .add("instantField", instantField)
+        .toString();
+    }
+  }
+
+  private static Row createRecordSpark30Row(Long index) {
+    Object[] values = new Object[] { LocalDate.ofEpochDay(42), Instant.ofEpochSecond(42) };
+    return new GenericRow(values);
+  }
+
+  private static RecordSpark30 createRecordSpark30(Row recordRow) {
+    RecordSpark30 record = new RecordSpark30();
+    record.setLocalDateField(String.valueOf(recordRow.getLocalDate(0)));
+    record.setInstantField(String.valueOf(recordRow.getInstant(1)));
+    return record;
+  }
 }

From 9442afdb2852c1201e19bf23161e60766286e444 Mon Sep 17 00:00:00 2001
From: Maxim Gekk
Date: Tue, 2 Apr 2019 00:45:19 +0200
Subject: [PATCH 3/6] Fix inferDataType

---
 .../scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
index 39132139237c..c5be3efc6371 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
@@ -102,7 +102,9 @@ object JavaTypeInference {
 
       case c: Class[_] if c == classOf[java.math.BigDecimal] => (DecimalType.SYSTEM_DEFAULT, true)
       case c: Class[_] if c == classOf[java.math.BigInteger] => (DecimalType.BigIntDecimal, true)
+      case c: Class[_] if c == classOf[java.time.LocalDate] => (DateType, true)
       case c: Class[_] if c == classOf[java.sql.Date] => (DateType, true)
+      case c: Class[_] if c == classOf[java.time.Instant] => (TimestampType, true)
       case c: Class[_] if c == classOf[java.sql.Timestamp] => (TimestampType, true)
 
       case _ if typeToken.isArray =>
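With those two cases in place, inferDataType maps bean properties of type java.time.LocalDate and java.time.Instant straight to DateType and TimestampType instead of falling through to the catch-all nested-bean handling. A sketch of the effect on schema inference; the Event bean and class name are hypothetical, not part of the patch:

    import java.time.Instant;
    import java.time.LocalDate;

    import org.apache.spark.sql.Encoder;
    import org.apache.spark.sql.Encoders;

    public class BeanSchemaExample {
      // Hypothetical bean whose fields exercise the two new cases.
      public static class Event {
        private LocalDate day;
        private Instant createdAt;
        public LocalDate getDay() { return day; }
        public void setDay(LocalDate day) { this.day = day; }
        public Instant getCreatedAt() { return createdAt; }
        public void setCreatedAt(Instant createdAt) { this.createdAt = createdAt; }
      }

      public static void main(String[] args) {
        Encoder<Event> encoder = Encoders.bean(Event.class);
        encoder.schema().printTreeString();
        // Bean properties are inspected in alphabetical order, so this prints:
        // root
        //  |-- createdAt: timestamp (nullable = true)
        //  |-- day: date (nullable = true)
      }
    }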
From eeebaad88a7a07c61aba94538b415983a74fa603 Mon Sep 17 00:00:00 2001
From: Maxim Gekk
Date: Tue, 2 Apr 2019 00:45:59 +0200
Subject: [PATCH 4/6] Fix the test

---
 .../sql/JavaBeanDeserializationSuite.java          | 46 ++++++++++++-------
 1 file changed, 29 insertions(+), 17 deletions(-)

diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaBeanDeserializationSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaBeanDeserializationSuite.java
index b0834184c4cb..f7092dbd0af7 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaBeanDeserializationSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaBeanDeserializationSuite.java
@@ -24,6 +24,9 @@
 
 import org.apache.spark.sql.*;
 import org.apache.spark.sql.catalyst.expressions.GenericRow;
+import org.apache.spark.sql.catalyst.util.DateTimeUtils;
+import org.apache.spark.sql.catalyst.util.TimestampFormatter;
+import org.apache.spark.sql.internal.SQLConf;
 import org.apache.spark.sql.types.DataTypes;
 import org.apache.spark.sql.types.StructType;
 import org.junit.*;
@@ -514,27 +517,33 @@ public void setId(Integer id) {
 
   @Test
   public void testSpark30() {
-    List<Row> inputRows = new ArrayList<>();
-    List<RecordSpark30> expectedRecords = new ArrayList<>();
-
-    for (long idx = 0 ; idx < 5 ; idx++) {
-      Row row = createRecordSpark30Row(idx);
-      inputRows.add(row);
-      expectedRecords.add(createRecordSpark30(row));
-    }
+    String originConf = spark.conf().get(SQLConf.DATETIME_JAVA8API_ENABLED().key());
+    try {
+      spark.conf().set(SQLConf.DATETIME_JAVA8API_ENABLED().key(), "true");
+      List<Row> inputRows = new ArrayList<>();
+      List<RecordSpark30> expectedRecords = new ArrayList<>();
+
+      for (long idx = 0 ; idx < 5 ; idx++) {
+        Row row = createRecordSpark30Row(idx);
+        inputRows.add(row);
+        expectedRecords.add(createRecordSpark30(row));
+      }
 
-    Encoder<RecordSpark30> encoder = Encoders.bean(RecordSpark30.class);
+      Encoder<RecordSpark30> encoder = Encoders.bean(RecordSpark30.class);
 
-    StructType schema = new StructType()
-      .add("localDateField", DataTypes.DateType)
-      .add("instantField", DataTypes.TimestampType);
+      StructType schema = new StructType()
+        .add("localDateField", DataTypes.DateType)
+        .add("instantField", DataTypes.TimestampType);
 
-    Dataset<Row> dataFrame = spark.createDataFrame(inputRows, schema);
-    Dataset<RecordSpark30> dataset = dataFrame.as(encoder);
+      Dataset<Row> dataFrame = spark.createDataFrame(inputRows, schema);
+      Dataset<RecordSpark30> dataset = dataFrame.as(encoder);
 
-    List<RecordSpark30> records = dataset.collectAsList();
+      List<RecordSpark30> records = dataset.collectAsList();
 
-    Assert.assertEquals(expectedRecords, records);
+      Assert.assertEquals(expectedRecords, records);
+    } finally {
+      spark.conf().set(SQLConf.DATETIME_JAVA8API_ENABLED().key(), originConf);
+    }
   }
 
   public static final class RecordSpark30 {
@@ -590,7 +599,10 @@ private static Row createRecordSpark30Row(Long index) {
   private static RecordSpark30 createRecordSpark30(Row recordRow) {
     RecordSpark30 record = new RecordSpark30();
     record.setLocalDateField(String.valueOf(recordRow.getLocalDate(0)));
-    record.setInstantField(String.valueOf(recordRow.getInstant(1)));
+    Instant instant = recordRow.getInstant(1);
+    TimestampFormatter formatter = TimestampFormatter.getFractionFormatter(
+        DateTimeUtils.getZoneId(SQLConf.get().sessionLocalTimeZone()));
+    record.setInstantField(formatter.format(DateTimeUtils.instantToMicros(instant)));
    return record;
   }
 }
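The change to createRecordSpark30 above is the subtle part of this fix: the bean under test stores the timestamp in a String field, and on deserialization Spark renders the TIMESTAMP value in the session time zone as "yyyy-MM-dd HH:mm:ss[.SSSSSS]", while Instant.toString() yields ISO-8601 in UTC, so the old String.valueOf(instant) could never match. A standalone sketch of the difference, reusing only calls that appear in the patch (the class name and the fixed "America/Los_Angeles" zone are arbitrary examples):

    import java.time.Instant;

    import org.apache.spark.sql.catalyst.util.DateTimeUtils;
    import org.apache.spark.sql.catalyst.util.TimestampFormatter;

    public class InstantFormattingExample {
      public static void main(String[] args) {
        Instant instant = Instant.ofEpochSecond(42);

        // ISO-8601 in UTC, with 'T' and 'Z':
        System.out.println(instant);  // 1970-01-01T00:00:42Z

        // Session-time-zone rendering, as the bean's String field receives it:
        TimestampFormatter formatter = TimestampFormatter.getFractionFormatter(
            DateTimeUtils.getZoneId("America/Los_Angeles"));
        System.out.println(formatter.format(DateTimeUtils.instantToMicros(instant)));
        // 1969-12-31 16:00:42
      }
    }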
From 2e366b9ede9d079af5741d8e848def21aae3a15c Mon Sep 17 00:00:00 2001
From: Maxim Gekk
Date: Tue, 2 Apr 2019 08:11:19 +0200
Subject: [PATCH 5/6] Update the comment for Java Bean encoder

---
 sql/catalyst/src/main/scala/org/apache/spark/sql/Encoders.scala | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoders.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoders.scala
index f54c6920ce82..055fbc49bdcd 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoders.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoders.scala
@@ -149,7 +149,7 @@ object Encoders {
    *  - boxed types: Boolean, Integer, Double, etc.
    *  - String
    *  - java.math.BigDecimal, java.math.BigInteger
-   *  - time related: java.sql.Date, java.sql.Timestamp
+   *  - time related: java.sql.Date, java.sql.Timestamp, java.time.LocalDate, java.time.Instant
    *  - collection types: only array and java.util.List currently, map support is in progress
    *  - nested java bean.
    *

From c39b0d5e2d0de277f97e9481418d59fdb0c995f5 Mon Sep 17 00:00:00 2001
From: Maxim Gekk
Date: Tue, 2 Apr 2019 14:01:31 +0200
Subject: [PATCH 6/6] Refactoring

---
 .../sql/JavaBeanDeserializationSuite.java          | 26 +++++++++----------
 1 file changed, 13 insertions(+), 13 deletions(-)

diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaBeanDeserializationSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaBeanDeserializationSuite.java
index f7092dbd0af7..7bf0789b43d6 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaBeanDeserializationSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaBeanDeserializationSuite.java
@@ -516,29 +516,29 @@ public void setId(Integer id) {
   }
 
   @Test
-  public void testSpark30() {
+  public void testBeanWithLocalDateAndInstant() {
     String originConf = spark.conf().get(SQLConf.DATETIME_JAVA8API_ENABLED().key());
     try {
       spark.conf().set(SQLConf.DATETIME_JAVA8API_ENABLED().key(), "true");
       List<Row> inputRows = new ArrayList<>();
-      List<RecordSpark30> expectedRecords = new ArrayList<>();
+      List<LocalDateInstantRecord> expectedRecords = new ArrayList<>();
 
       for (long idx = 0 ; idx < 5 ; idx++) {
-        Row row = createRecordSpark30Row(idx);
+        Row row = createLocalDateInstantRow(idx);
         inputRows.add(row);
-        expectedRecords.add(createRecordSpark30(row));
+        expectedRecords.add(createLocalDateInstantRecord(row));
       }
 
-      Encoder<RecordSpark30> encoder = Encoders.bean(RecordSpark30.class);
+      Encoder<LocalDateInstantRecord> encoder = Encoders.bean(LocalDateInstantRecord.class);
 
       StructType schema = new StructType()
         .add("localDateField", DataTypes.DateType)
        .add("instantField", DataTypes.TimestampType);
 
       Dataset<Row> dataFrame = spark.createDataFrame(inputRows, schema);
-      Dataset<RecordSpark30> dataset = dataFrame.as(encoder);
+      Dataset<LocalDateInstantRecord> dataset = dataFrame.as(encoder);
 
-      List<RecordSpark30> records = dataset.collectAsList();
+      List<LocalDateInstantRecord> records = dataset.collectAsList();
 
       Assert.assertEquals(expectedRecords, records);
     } finally {
@@ -546,11 +546,11 @@ public void setId(Integer id) {
     }
   }
 
-  public static final class RecordSpark30 {
+  public static final class LocalDateInstantRecord {
     private String localDateField;
     private String instantField;
 
-    public RecordSpark30() { }
+    public LocalDateInstantRecord() { }
 
     public String getLocalDateField() {
       return localDateField;
@@ -572,7 +572,7 @@ public boolean equals(Object o) {
       if (this == o) return true;
       if (o == null || getClass() != o.getClass()) return false;
-      RecordSpark30 that = (RecordSpark30) o;
+      LocalDateInstantRecord that = (LocalDateInstantRecord) o;
       return Objects.equals(localDateField, that.localDateField) &&
         Objects.equals(instantField, that.instantField);
     }
@@ -591,13 +591,13 @@ public String toString() {
       return com.google.common.base.Objects.toStringHelper(this)
         .add("localDateField", localDateField)
         .add("instantField", instantField)
         .toString();
     }
   }
 
-  private static Row createRecordSpark30Row(Long index) {
+  private static Row createLocalDateInstantRow(Long index) {
     Object[] values = new Object[] { LocalDate.ofEpochDay(42), Instant.ofEpochSecond(42) };
     return new GenericRow(values);
   }
 
-  private static RecordSpark30 createRecordSpark30(Row recordRow) {
-    RecordSpark30 record = new RecordSpark30();
+  private static LocalDateInstantRecord createLocalDateInstantRecord(Row recordRow) {
+    LocalDateInstantRecord record = new LocalDateInstantRecord();
     record.setLocalDateField(String.valueOf(recordRow.getLocalDate(0)));
     Instant instant = recordRow.getInstant(1);
     TimestampFormatter formatter = TimestampFormatter.getFractionFormatter(
         DateTimeUtils.getZoneId(SQLConf.get().sessionLocalTimeZone()));
     record.setInstantField(formatter.format(DateTimeUtils.instantToMicros(instant)));
     return record;
   }
 }