From 45fe2146f18d3aa66238ecabeed7896458b84975 Mon Sep 17 00:00:00 2001 From: scwf Date: Sat, 9 May 2015 21:58:03 +0800 Subject: [PATCH 1/4] Clean up all the inbound/outbound conversions for DateType --- .../sql/catalyst/expressions/UnsafeRow.java | 7 +++++++ .../main/scala/org/apache/spark/sql/Row.scala | 5 ++--- .../expressions/SpecificMutableRow.scala | 9 +++++++++ .../expressions/codegen/CodeGenerator.scala | 1 + .../codegen/GenerateProjection.scala | 18 ++++++++++++++---- .../spark/sql/catalyst/expressions/rows.scala | 10 +++++++--- 6 files changed, 40 insertions(+), 10 deletions(-) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java index bb546b3086b3..968abfb0fd08 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java @@ -17,6 +17,7 @@ package org.apache.spark.sql.catalyst.expressions; +import org.apache.spark.sql.types.DateUtils; import scala.collection.Map; import scala.collection.Seq; import scala.collection.mutable.ArraySeq; @@ -217,6 +218,12 @@ public void setFloat(int ordinal, float value) { PlatformDependent.UNSAFE.putFloat(baseObject, getFieldOffset(ordinal), value); } + + @Override + public void setDate(int ordinal, Date value) { + setInt(ordinal, DateUtils.fromJavaDate(value)); + } + @Override public void setString(int ordinal, String value) { throw new UnsupportedOperationException(); diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala index 4190b7ffe1c8..ecf9689f89bc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql import scala.util.hashing.MurmurHash3 import org.apache.spark.sql.catalyst.expressions.GenericRow -import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.types.{DateUtils, StructType} object Row { /** @@ -257,8 +257,7 @@ trait Row extends Serializable { * * @throws ClassCastException when data type does not match. */ - // TODO(davies): This is not the right default implementation, we use Int as Date internally - def getDate(i: Int): java.sql.Date = apply(i).asInstanceOf[java.sql.Date] + def getDate(i: Int): java.sql.Date = DateUtils.toJavaDate(getInt(i)) /** * Returns the value at position i of array type as a Scala Seq. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala index aa4099e4d7bf..0f88b4a577d8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala @@ -313,7 +313,16 @@ final class SpecificMutableRow(val values: Array[MutableValue]) extends MutableR values(i).asInstanceOf[MutableByte].value } + override def setDate(ordinal: Int, value: java.sql.Date): Unit = { + setInt(ordinal, DateUtils.fromJavaDate(value)) + } + + override def getDate(i: Int): java.sql.Date = { + DateUtils.toJavaDate(values(i).asInstanceOf[MutableInt].value) + } + override def getAs[T](i: Int): T = { values(i).boxed.asInstanceOf[T] } + } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index d17af0e7ff87..a0c6a644576f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -660,6 +660,7 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin case FloatType => "Float" case BooleanType => "Boolean" case StringType => "org.apache.spark.sql.types.UTF8String" + case DateType => "Int" } protected def defaultPrimitive(dt: DataType) = dt match { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala index 584f938445c8..f3865aa93b5f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala @@ -111,8 +111,8 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] { val specificAccessorFunctions = nativeTypes.map { dataType => val ifStatements = expressions.zipWithIndex.flatMap { - // getString() is not used by expressions - case (e, i) if e.dataType == dataType && dataType != StringType => + // getString() and getDate are not used by expressions + case (e, i) if e.dataType == dataType && dataType != StringType && dataType != DateType => val elementName = newTermName(s"c$i") // TODO: The string of ifs gets pretty inefficient as the row grows in size. // TODO: Optional null checks? @@ -126,6 +126,11 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] { override def getString(i: Int): String = { $accessorFailure }""" + case DateType => + q""" + override def getDate(i: Int): java.sql.Date = { + $accessorFailure + }""" case other => q""" override def ${accessorForType(dataType)}(i: Int): ${termForType(dataType)} = { @@ -137,8 +142,8 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] { val specificMutatorFunctions = nativeTypes.map { dataType => val ifStatements = expressions.zipWithIndex.flatMap { - // setString() is not used by expressions - case (e, i) if e.dataType == dataType && dataType != StringType => + // setString() and setDate are not used by expressions + case (e, i) if e.dataType == dataType && dataType != StringType && dataType != DateType => val elementName = newTermName(s"c$i") // TODO: The string of ifs gets pretty inefficient as the row grows in size. // TODO: Optional null checks? @@ -152,6 +157,11 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] { override def setString(i: Int, value: String) { $accessorFailure }""" + case DateType => + q""" + override def setDate(i: Int, value: java.sql.Date) { + $accessorFailure + }""" case other => q""" override def ${mutatorForType(dataType)}(i: Int, value: ${termForType(dataType)}) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala index 5fd892c42e69..ff09ee7a51e5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.expressions -import org.apache.spark.sql.types.{UTF8String, DataType, StructType, AtomicType} +import org.apache.spark.sql.types._ /** * An extended interface to [[Row]] that allows the values for each column to be updated. Setting @@ -36,7 +36,8 @@ trait MutableRow extends Row { def setByte(ordinal: Int, value: Byte) def setFloat(ordinal: Int, value: Float) def setString(ordinal: Int, value: String) - // TODO(davies): add setDate() and setDecimal() + def setDate(ordinal:Int, value: java.sql.Date) + // TODO(davies): add setDecimal() } /** @@ -121,7 +122,7 @@ class GenericRow(protected[sql] val values: Array[Any]) extends Row { } } - // TODO(davies): add getDate and getDecimal + // TODO(davies): add getDecimal // Custom hashCode function that matches the efficient code generated version. override def hashCode: Int = { @@ -199,6 +200,9 @@ class GenericMutableRow(v: Array[Any]) extends GenericRow(v) with MutableRow { override def setLong(ordinal: Int, value: Long): Unit = { values(ordinal) = value } override def setString(ordinal: Int, value: String) { values(ordinal) = UTF8String(value)} override def setNullAt(i: Int): Unit = { values(i) = null } + override def setDate(ordinal:Int, value: java.sql.Date): Unit = { + values(ordinal) = DateUtils.fromJavaDate(value) + } override def setShort(ordinal: Int, value: Short): Unit = { values(ordinal) = value } From 2606a5698007eeaedc374361bb2378c24be5872b Mon Sep 17 00:00:00 2001 From: scwf Date: Sat, 9 May 2015 22:04:36 +0800 Subject: [PATCH 2/4] fix tests --- .../spark/sql/catalyst/expressions/UnsafeRow.java | 2 +- .../catalyst/expressions/codegen/CodeGenerator.scala | 11 ++++++++++- .../apache/spark/sql/catalyst/expressions/rows.scala | 1 + 3 files changed, 12 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java index 968abfb0fd08..380cf7893f1a 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java @@ -17,7 +17,6 @@ package org.apache.spark.sql.catalyst.expressions; -import org.apache.spark.sql.types.DateUtils; import scala.collection.Map; import scala.collection.Seq; import scala.collection.mutable.ArraySeq; @@ -29,6 +28,7 @@ import org.apache.spark.sql.Row; import org.apache.spark.sql.types.DataType; +import org.apache.spark.sql.types.DateUtils; import static org.apache.spark.sql.types.DataTypes.*; import org.apache.spark.sql.types.StructType; import org.apache.spark.sql.types.UTF8String; diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index a0c6a644576f..d060327000d9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -686,7 +686,16 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin * List of data types that have special accessors and setters in [[Row]]. */ protected val nativeTypes = - Seq(IntegerType, BooleanType, LongType, DoubleType, FloatType, ShortType, ByteType, StringType) + Seq( + IntegerType, + BooleanType, + LongType, + DoubleType, + FloatType, + ShortType, + ByteType, + StringType, + DateType) /** * Returns true if the data type has a special accessor and setter in [[Row]]. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala index ff09ee7a51e5..52a207a8f6e4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala @@ -56,6 +56,7 @@ object EmptyRow extends Row { override def getShort(i: Int): Short = throw new UnsupportedOperationException override def getByte(i: Int): Byte = throw new UnsupportedOperationException override def getString(i: Int): String = throw new UnsupportedOperationException + override def getDate(i: Int): java.sql.Date = throw new UnsupportedOperationException override def getAs[T](i: Int): T = throw new UnsupportedOperationException override def copy(): Row = this } From ae30c22cca9352742d13554fd3ad9e81c427971d Mon Sep 17 00:00:00 2001 From: scwf Date: Sat, 9 May 2015 22:11:55 +0800 Subject: [PATCH 3/4] style --- .../spark/sql/catalyst/expressions/SpecificMutableRow.scala | 1 - 1 file changed, 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala index 0f88b4a577d8..b54bd60258d9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala @@ -324,5 +324,4 @@ final class SpecificMutableRow(val values: Array[MutableValue]) extends MutableR override def getAs[T](i: Int): T = { values(i).boxed.asInstanceOf[T] } - } From 274b39e3077401818c6d7ac1304020af3e8afe95 Mon Sep 17 00:00:00 2001 From: scwf Date: Sun, 10 May 2015 08:14:09 +0800 Subject: [PATCH 4/4] added unit test for date support --- .../org/apache/spark/sql/SQLQuerySuite.scala | 42 ++++++++++++++++++- .../scala/org/apache/spark/sql/TestData.scala | 11 ++++- 2 files changed, 51 insertions(+), 2 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index b44eb223c80c..4f169e1222e4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -17,12 +17,13 @@ package org.apache.spark.sql +import java.sql.Date + import org.scalatest.BeforeAndAfterAll import org.apache.spark.sql.catalyst.errors.DialectException import org.apache.spark.sql.execution.GeneratedAggregate import org.apache.spark.sql.functions._ -import org.apache.spark.sql.catalyst.CatalystConf import org.apache.spark.sql.TestData._ import org.apache.spark.sql.test.TestSQLContext import org.apache.spark.sql.test.TestSQLContext.{udf => _, _} @@ -501,6 +502,45 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { ) } + test("date support") { + checkAnswer(sql( + "SELECT date FROM dates"), + Seq( + Row(Date.valueOf("1970-01-01")), + Row(Date.valueOf("1970-01-02")), + Row(Date.valueOf("1970-01-03")))) + + checkAnswer(sql( + "SELECT date FROM dates WHERE date=CAST('1970-01-01' AS date)"), + Row(Date.valueOf("1970-01-01"))) + + checkAnswer(sql( + "SELECT date FROM dates WHERE date='1970-01-01'"), + Row(Date.valueOf("1970-01-01"))) + + checkAnswer(sql( + "SELECT date FROM dates WHERE '1970-01-01'=date"), + Row(Date.valueOf("1970-01-01"))) + + checkAnswer(sql( + """SELECT date FROM dates WHERE date<'1970-01-03' + AND date>'1970-01-01'"""), + Row(Date.valueOf("1970-01-02"))) + + checkAnswer(sql( + """ + |SELECT date FROM dates + |WHERE date IN ('1970-01-01','1970-01-02') + """.stripMargin), + Seq( + Row(Date.valueOf("1970-01-01")), + Row(Date.valueOf("1970-01-02")))) + + checkAnswer(sql( + "SELECT date FROM dates WHERE date='123'"), + Nil) + } + test("from follow multiple brackets") { checkAnswer(sql( """ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala index 225b51bd73d6..df038beb8e6a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql -import java.sql.Timestamp +import java.sql.{Date, Timestamp} import org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.test.TestSQLContext.implicits._ @@ -176,6 +176,15 @@ object TestData { "3, C3, true, null" :: "4, D4, true, 2147483644" :: Nil) + case class DateField(date: java.sql.Date) + val dates = TestSQLContext.sparkContext.parallelize( + Seq( + Date.valueOf("1970-01-01"), + Date.valueOf("1970-01-02"), + Date.valueOf("1970-01-03")).map(DateField(_)) + ) + dates.toDF().registerTempTable("dates") + case class TimestampField(time: Timestamp) val timestamps = TestSQLContext.sparkContext.parallelize((1 to 3).map { i => TimestampField(new Timestamp(i))