@@ -28,6 +28,7 @@

import org.apache.spark.sql.Row;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.DateUtils;
import static org.apache.spark.sql.types.DataTypes.*;
import org.apache.spark.sql.types.StructType;
import org.apache.spark.sql.types.UTF8String;
@@ -217,6 +218,12 @@ public void setFloat(int ordinal, float value) {
PlatformDependent.UNSAFE.putFloat(baseObject, getFieldOffset(ordinal), value);
}


@Override
public void setDate(int ordinal, Date value) {
setInt(ordinal, DateUtils.fromJavaDate(value));
}

@Override
public void setString(int ordinal, String value) {
throw new UnsupportedOperationException();
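The pattern in this hunk recurs throughout the PR: a DateType value is stored internally as an Int counting days since the Unix epoch, and DateUtils.fromJavaDate / DateUtils.toJavaDate convert at the java.sql.Date boundary. A minimal sketch of the assumed conversion semantics (an illustrative reimplementation, not the Spark source):

import java.sql.Date
import java.util.TimeZone

object DateUtilsSketch {
  private val MillisPerDay = 86400000L

  // java.sql.Date.getTime is epoch millis in the default time zone, so shift
  // into local time before dividing into whole days.
  def fromJavaDate(date: Date): Int = {
    val millisLocal = date.getTime + TimeZone.getDefault.getOffset(date.getTime)
    Math.floorDiv(millisLocal, MillisPerDay).toInt
  }

  def toJavaDate(days: Int): Date = {
    val millisLocal = days * MillisPerDay
    new Date(millisLocal - TimeZone.getDefault.getOffset(millisLocal))
  }

  def main(args: Array[String]): Unit = {
    val d = Date.valueOf("1970-01-03")
    val days = fromJavaDate(d)                      // 2
    println(s"$d -> $days -> ${toJavaDate(days)}")  // 1970-01-03 -> 2 -> 1970-01-03
  }
}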
sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala (5 changes: 2 additions & 3 deletions)
@@ -20,7 +20,7 @@ package org.apache.spark.sql
import scala.util.hashing.MurmurHash3

import org.apache.spark.sql.catalyst.expressions.GenericRow
-import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.types.{DateUtils, StructType}

object Row {
/**
@@ -257,8 +257,7 @@ trait Row extends Serializable {
*
* @throws ClassCastException when data type does not match.
*/
-// TODO(davies): This is not the right default implementation, we use Int as Date internally
-def getDate(i: Int): java.sql.Date = apply(i).asInstanceOf[java.sql.Date]
+def getDate(i: Int): java.sql.Date = DateUtils.toJavaDate(getInt(i))

/**
* Returns the value at position i of array type as a Scala Seq.
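With the new default implementation, any Row that holds the internal Int form can serve the value back as a java.sql.Date. A usage sketch, assuming only the trait default shown above (2 is the internal encoding of 1970-01-03):

import java.sql.Date
import org.apache.spark.sql.Row

val r = Row(2)                         // internal form: days since epoch
assert(r.getInt(0) == 2)
assert(r.getDate(0) == Date.valueOf("1970-01-03"))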
@@ -313,6 +313,14 @@ final class SpecificMutableRow(val values: Array[MutableValue]) extends MutableRow {
values(i).asInstanceOf[MutableByte].value
}

override def setDate(ordinal: Int, value: java.sql.Date): Unit = {
setInt(ordinal, DateUtils.fromJavaDate(value))
}

override def getDate(i: Int): java.sql.Date = {
DateUtils.toJavaDate(values(i).asInstanceOf[MutableInt].value)
}

override def getAs[T](i: Int): T = {
values(i).boxed.asInstanceOf[T]
}
@@ -660,6 +660,7 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Logging {
case FloatType => "Float"
case BooleanType => "Boolean"
case StringType => "org.apache.spark.sql.types.UTF8String"
case DateType => "Int"
}

protected def defaultPrimitive(dt: DataType) = dt match {
@@ -685,7 +686,16 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Logging {
* List of data types that have special accessors and setters in [[Row]].
*/
protected val nativeTypes =
-Seq(IntegerType, BooleanType, LongType, DoubleType, FloatType, ShortType, ByteType, StringType)
+Seq(
+  IntegerType,
+  BooleanType,
+  LongType,
+  DoubleType,
+  FloatType,
+  ShortType,
+  ByteType,
+  StringType,
+  DateType)

/**
* Returns true if the data type has a special accessor and setter in [[Row]].
@@ -111,8 +111,8 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] {

val specificAccessorFunctions = nativeTypes.map { dataType =>
val ifStatements = expressions.zipWithIndex.flatMap {
-// getString() is not used by expressions
-case (e, i) if e.dataType == dataType && dataType != StringType =>
+// getString() and getDate() are not used by expressions
+case (e, i) if e.dataType == dataType && dataType != StringType && dataType != DateType =>
val elementName = newTermName(s"c$i")
// TODO: The string of ifs gets pretty inefficient as the row grows in size.
// TODO: Optional null checks?
@@ -126,6 +126,11 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] {
override def getString(i: Int): String = {
$accessorFailure
}"""
case DateType =>
q"""
override def getDate(i: Int): java.sql.Date = {
$accessorFailure
}"""
case other =>
q"""
override def ${accessorForType(dataType)}(i: Int): ${termForType(dataType)} = {
@@ -137,8 +142,8 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] {

val specificMutatorFunctions = nativeTypes.map { dataType =>
val ifStatements = expressions.zipWithIndex.flatMap {
-// setString() is not used by expressions
-case (e, i) if e.dataType == dataType && dataType != StringType =>
+// setString() and setDate() are not used by expressions
+case (e, i) if e.dataType == dataType && dataType != StringType && dataType != DateType =>
val elementName = newTermName(s"c$i")
// TODO: The string of ifs gets pretty inefficient as the row grows in size.
// TODO: Optional null checks?
@@ -152,6 +157,11 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] {
override def setString(i: Int, value: String) {
$accessorFailure
}"""
case DateType =>
q"""
override def setDate(i: Int, value: java.sql.Date) {
$accessorFailure
}"""
case other =>
q"""
override def ${mutatorForType(dataType)}(i: Int, value: ${termForType(dataType)}) {
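Taken together, the two codegen changes mean a DateType column in a generated projection occupies a primitive Int slot, while the java.sql.Date accessors are emitted only as failure stubs; expressions read and write the column through getInt/setInt. A hand-written sketch of the shape the generated class is assumed to take for a single date column (names illustrative, not actual codegen output):

final class GeneratedDateRowSketch {
  private var c0: Int = 0              // the date column, as days since epoch

  def getInt(i: Int): Int =
    if (i == 0) c0 else throw new IllegalArgumentException(s"no Int column at $i")

  def setInt(i: Int, value: Int): Unit =
    if (i == 0) c0 = value else throw new IllegalArgumentException(s"no Int column at $i")

  // Emitted as stubs: expressions never call the java.sql.Date accessors.
  def getDate(i: Int): java.sql.Date = throw new UnsupportedOperationException
  def setDate(i: Int, value: java.sql.Date): Unit = throw new UnsupportedOperationException
}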
@@ -17,7 +17,7 @@

package org.apache.spark.sql.catalyst.expressions

-import org.apache.spark.sql.types.{UTF8String, DataType, StructType, AtomicType}
+import org.apache.spark.sql.types._

/**
* An extended interface to [[Row]] that allows the values for each column to be updated. Setting
@@ -36,7 +36,8 @@ trait MutableRow extends Row {
def setByte(ordinal: Int, value: Byte)
def setFloat(ordinal: Int, value: Float)
def setString(ordinal: Int, value: String)
-// TODO(davies): add setDate() and setDecimal()
+def setDate(ordinal: Int, value: java.sql.Date)
+// TODO(davies): add setDecimal()
}

/**
@@ -55,6 +56,7 @@ object EmptyRow extends Row {
override def getShort(i: Int): Short = throw new UnsupportedOperationException
override def getByte(i: Int): Byte = throw new UnsupportedOperationException
override def getString(i: Int): String = throw new UnsupportedOperationException
override def getDate(i: Int): java.sql.Date = throw new UnsupportedOperationException
override def getAs[T](i: Int): T = throw new UnsupportedOperationException
override def copy(): Row = this
}
@@ -121,7 +123,7 @@ class GenericRow(protected[sql] val values: Array[Any]) extends Row {
}
}

-// TODO(davies): add getDate and getDecimal
+// TODO(davies): add getDecimal

// Custom hashCode function that matches the efficient code generated version.
override def hashCode: Int = {
@@ -199,6 +201,9 @@ class GenericMutableRow(v: Array[Any]) extends GenericRow(v) with MutableRow {
override def setLong(ordinal: Int, value: Long): Unit = { values(ordinal) = value }
override def setString(ordinal: Int, value: String) { values(ordinal) = UTF8String(value) }
override def setNullAt(i: Int): Unit = { values(i) = null }
override def setDate(ordinal: Int, value: java.sql.Date): Unit = {
values(ordinal) = DateUtils.fromJavaDate(value)
}

override def setShort(ordinal: Int, value: Short): Unit = { values(ordinal) = value }

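A quick check of the new mutable-row behavior (GenericMutableRow is internal catalyst API, so treat this as a sketch): setDate stores the converted Int, and the Int and Date views of the slot agree.

import java.sql.Date
import org.apache.spark.sql.catalyst.expressions.GenericMutableRow

val row = new GenericMutableRow(1)
row.setDate(0, Date.valueOf("1970-01-02"))
assert(row.getInt(0) == 1)             // stored as days since epoch
assert(row.getDate(0) == Date.valueOf("1970-01-02"))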
sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala (42 changes: 41 additions & 1 deletion)
@@ -17,12 +17,13 @@

package org.apache.spark.sql

import java.sql.Date

import org.scalatest.BeforeAndAfterAll

import org.apache.spark.sql.catalyst.errors.DialectException
import org.apache.spark.sql.execution.GeneratedAggregate
import org.apache.spark.sql.functions._
-import org.apache.spark.sql.catalyst.CatalystConf
import org.apache.spark.sql.TestData._
import org.apache.spark.sql.test.TestSQLContext
import org.apache.spark.sql.test.TestSQLContext.{udf => _, _}
@@ -501,6 +502,45 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
)
}

test("date support") {
checkAnswer(sql(
"SELECT date FROM dates"),
Seq(
Row(Date.valueOf("1970-01-01")),
Row(Date.valueOf("1970-01-02")),
Row(Date.valueOf("1970-01-03"))))

checkAnswer(sql(
"SELECT date FROM dates WHERE date=CAST('1970-01-01' AS date)"),
Row(Date.valueOf("1970-01-01")))

checkAnswer(sql(
"SELECT date FROM dates WHERE date='1970-01-01'"),
Row(Date.valueOf("1970-01-01")))

checkAnswer(sql(
"SELECT date FROM dates WHERE '1970-01-01'=date"),
Row(Date.valueOf("1970-01-01")))

checkAnswer(sql(
"""SELECT date FROM dates WHERE date<'1970-01-03'
AND date>'1970-01-01'"""),
Row(Date.valueOf("1970-01-02")))

checkAnswer(sql(
"""
|SELECT date FROM dates
|WHERE date IN ('1970-01-01','1970-01-02')
""".stripMargin),
Seq(
Row(Date.valueOf("1970-01-01")),
Row(Date.valueOf("1970-01-02"))))

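// '123' is not a parseable date, so the cast is assumed to yield null and the
// comparison matches no rows: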
checkAnswer(sql(
"SELECT date FROM dates WHERE date='123'"),
Nil)
}

test("from follow multiple brackets") {
checkAnswer(sql(
"""
sql/core/src/test/scala/org/apache/spark/sql/TestData.scala (11 changes: 10 additions & 1 deletion)
@@ -17,7 +17,7 @@

package org.apache.spark.sql

-import java.sql.Timestamp
+import java.sql.{Date, Timestamp}

import org.apache.spark.sql.catalyst.plans.logical
import org.apache.spark.sql.test.TestSQLContext.implicits._
@@ -176,6 +176,15 @@ object TestData {
"3, C3, true, null" ::
"4, D4, true, 2147483644" :: Nil)

case class DateField(date: java.sql.Date)
val dates = TestSQLContext.sparkContext.parallelize(
Seq(
Date.valueOf("1970-01-01"),
Date.valueOf("1970-01-02"),
Date.valueOf("1970-01-03")).map(DateField(_))
)
dates.toDF().registerTempTable("dates")

case class TimestampField(time: Timestamp)
val timestamps = TestSQLContext.sparkContext.parallelize((1 to 3).map { i =>
TimestampField(new Timestamp(i))