diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
index ef1d12531f109..7dff9a26e8c98 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
@@ -17,6 +17,7 @@
 
 package org.apache.spark.sql.catalyst.expressions
 
+import java.sql.Timestamp
 
 /**
  * A [[Projection]] that is calculated by calling the `eval` of each of the specified expressions.
@@ -137,6 +138,15 @@ class JoinedRow extends Row {
   def getString(i: Int): String =
     if (i < row1.size) row1.getString(i) else row2.getString(i - row1.size)
 
+  def getDecimal(i: Int): BigDecimal =
+    if (i < row1.size) row1.getDecimal(i) else row2.getDecimal(i - row1.size)
+
+  def getTimestamp(i: Int): Timestamp =
+    if (i < row1.size) row1.getTimestamp(i) else row2.getTimestamp(i - row1.size)
+
+  def getBinary(i: Int): Array[Byte] =
+    if (i < row1.size) row1.getBinary(i) else row2.getBinary(i - row1.size)
+
   def copy() = {
     val totalSize = row1.size + row2.size
     val copiedValues = new Array[Any](totalSize)
@@ -226,6 +236,15 @@ class JoinedRow2 extends Row {
   def getString(i: Int): String =
     if (i < row1.size) row1.getString(i) else row2.getString(i - row1.size)
 
+  def getDecimal(i: Int): BigDecimal =
+    if (i < row1.size) row1.getDecimal(i) else row2.getDecimal(i - row1.size)
+
+  def getTimestamp(i: Int): Timestamp =
+    if (i < row1.size) row1.getTimestamp(i) else row2.getTimestamp(i - row1.size)
+
+  def getBinary(i: Int): Array[Byte] =
+    if (i < row1.size) row1.getBinary(i) else row2.getBinary(i - row1.size)
+
   def copy() = {
     val totalSize = row1.size + row2.size
     val copiedValues = new Array[Any](totalSize)
@@ -309,6 +328,15 @@ class JoinedRow3 extends Row {
   def getString(i: Int): String =
     if (i < row1.size) row1.getString(i) else row2.getString(i - row1.size)
 
+  def getDecimal(i: Int): BigDecimal =
+    if (i < row1.size) row1.getDecimal(i) else row2.getDecimal(i - row1.size)
+
+  def getTimestamp(i: Int): Timestamp =
+    if (i < row1.size) row1.getTimestamp(i) else row2.getTimestamp(i - row1.size)
+
+  def getBinary(i: Int): Array[Byte] =
+    if (i < row1.size) row1.getBinary(i) else row2.getBinary(i - row1.size)
+
   def copy() = {
     val totalSize = row1.size + row2.size
     val copiedValues = new Array[Any](totalSize)
@@ -392,6 +420,15 @@ class JoinedRow4 extends Row {
   def getString(i: Int): String =
     if (i < row1.size) row1.getString(i) else row2.getString(i - row1.size)
 
+  def getDecimal(i: Int): BigDecimal =
+    if (i < row1.size) row1.getDecimal(i) else row2.getDecimal(i - row1.size)
+
+  def getTimestamp(i: Int): Timestamp =
+    if (i < row1.size) row1.getTimestamp(i) else row2.getTimestamp(i - row1.size)
+
+  def getBinary(i: Int): Array[Byte] =
+    if (i < row1.size) row1.getBinary(i) else row2.getBinary(i - row1.size)
+
   def copy() = {
     val totalSize = row1.size + row2.size
     val copiedValues = new Array[Any](totalSize)
@@ -475,6 +512,15 @@ class JoinedRow5 extends Row {
   def getString(i: Int): String =
     if (i < row1.size) row1.getString(i) else row2.getString(i - row1.size)
 
+  def getDecimal(i: Int): BigDecimal =
+    if (i < row1.size) row1.getDecimal(i) else row2.getDecimal(i - row1.size)
+
+  def getTimestamp(i: Int): Timestamp =
+    if (i < row1.size) row1.getTimestamp(i) else row2.getTimestamp(i - row1.size)
+
+  def getBinary(i: Int): Array[Byte] =
+    if (i < row1.size) row1.getBinary(i) else row2.getBinary(i - row1.size)
+
   def copy() = {
     val totalSize = row1.size + row2.size
     val copiedValues = new Array[Any](totalSize)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Row.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Row.scala
index d68a4fabeac77..041013c3f27d1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Row.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Row.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.sql.catalyst.expressions
 
+import java.sql.Timestamp
+
 import org.apache.spark.sql.catalyst.types.NativeType
 
 object Row {
@@ -64,6 +66,9 @@ trait Row extends Seq[Any] with Serializable {
   def getShort(i: Int): Short
   def getByte(i: Int): Byte
   def getString(i: Int): String
+  def getDecimal(i: Int): BigDecimal
+  def getTimestamp(i: Int): Timestamp
+  def getBinary(i: Int): Array[Byte]
 
   override def toString() = s"[${this.mkString(",")}]"
 
@@ -98,6 +103,9 @@ trait MutableRow extends Row {
   def setByte(ordinal: Int, value: Byte)
   def setFloat(ordinal: Int, value: Float)
   def setString(ordinal: Int, value: String)
+  def setDecimal(ordinal: Int, value: BigDecimal)
+  def setTimestamp(ordinal: Int, value: Timestamp)
+  def setBinary(ordinal: Int, value: Array[Byte])
 }
 
 /**
@@ -118,6 +126,9 @@ object EmptyRow extends Row {
   def getShort(i: Int): Short = throw new UnsupportedOperationException
   def getByte(i: Int): Byte = throw new UnsupportedOperationException
   def getString(i: Int): String = throw new UnsupportedOperationException
+  def getDecimal(i: Int): BigDecimal = throw new UnsupportedOperationException
+  def getTimestamp(i: Int): Timestamp = throw new UnsupportedOperationException
+  def getBinary(i: Int): Array[Byte] = throw new UnsupportedOperationException
 
   def copy() = this
 }
@@ -181,6 +192,21 @@ class GenericRow(protected[sql] val values: Array[Any]) extends Row {
     values(i).asInstanceOf[String]
   }
 
+  def getDecimal(i: Int): BigDecimal = {
+    if (values(i) == null) sys.error("Failed to check null bit for primitive Decimal value.")
+    values(i).asInstanceOf[BigDecimal]
+  }
+
+  def getTimestamp(i: Int): Timestamp = {
+    if (values(i) == null) sys.error("Failed to check null bit for primitive Timestamp value.")
+    values(i).asInstanceOf[Timestamp]
+  }
+
+  def getBinary(i: Int): Array[Byte] = {
+    if (values(i) == null) sys.error("Failed to check null bit for primitive Binary value.")
+    values(i).asInstanceOf[Array[Byte]]
+  }
+
   // Custom hashCode function that matches the efficient code generated version.
   override def hashCode(): Int = {
     var result: Int = 37
@@ -201,6 +227,7 @@ class GenericRow(protected[sql] val values: Array[Any]) extends Row {
       case d: Double =>
         val b = java.lang.Double.doubleToLongBits(d)
         (b ^ (b >>> 32)).toInt
+      case b: Array[Byte] => 123 // TODO need to figure out how to compute the hashcode
       case other => other.hashCode()
     }
   }
@@ -224,6 +251,9 @@ class GenericMutableRow(size: Int) extends GenericRow(size) with MutableRow {
   override def setInt(ordinal: Int,value: Int): Unit = { values(ordinal) = value }
   override def setLong(ordinal: Int,value: Long): Unit = { values(ordinal) = value }
   override def setString(ordinal: Int,value: String): Unit = { values(ordinal) = value }
+  override def setDecimal(ordinal: Int, value: BigDecimal): Unit = { values(ordinal) = value }
+  override def setTimestamp(ordinal: Int, value: Timestamp): Unit = { values(ordinal) = value }
+  override def setBinary(ordinal: Int, value: Array[Byte]): Unit = { values(ordinal) = value }
 
   override def setNullAt(i: Int): Unit = { values(i) = null }
 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificRow.scala
index 75ea0e8459df8..7005acf331ea6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificRow.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificRow.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.sql.catalyst.expressions
 
+import java.sql.Timestamp
+
 import org.apache.spark.sql.catalyst.types._
 
 /**
@@ -231,9 +233,9 @@ final class SpecificMutableRow(val values: Array[MutableValue]) extends MutableR
 
   override def iterator: Iterator[Any] = values.map(_.boxed).iterator
 
-  def setString(ordinal: Int, value: String) = update(ordinal, value)
+  override def setString(ordinal: Int, value: String) = update(ordinal, value)
 
-  def getString(ordinal: Int) = apply(ordinal).asInstanceOf[String]
+  override def getString(ordinal: Int) = apply(ordinal).asInstanceOf[String]
 
   override def setInt(ordinal: Int, value: Int): Unit = {
     val currentValue = values(ordinal).asInstanceOf[MutableInt]
@@ -304,4 +306,16 @@ final class SpecificMutableRow(val values: Array[MutableValue]) extends MutableR
   override def getByte(i: Int): Byte = {
     values(i).asInstanceOf[MutableByte].value
   }
+
+  override def setDecimal(ordinal: Int, value: BigDecimal): Unit = update(ordinal, value)
+
+  override def getDecimal(i: Int): BigDecimal = apply(i).asInstanceOf[BigDecimal]
+
+  override def setTimestamp(ordinal: Int, value: Timestamp): Unit = update(ordinal, value)
+
+  override def getTimestamp(i: Int): Timestamp = apply(i).asInstanceOf[Timestamp]
+
+  override def setBinary(ordinal: Int, value: Array[Byte]): Unit = update(ordinal, value)
+
+  override def getBinary(i: Int): Array[Byte] = apply(i).asInstanceOf[Array[Byte]]
 }
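
For review context, a minimal usage sketch of the typed accessors this patch adds. It is not part of the diff: the object name TypedAccessorDemo is invented for illustration, and it assumes the patched Row, GenericMutableRow, and JoinedRow classes above are on the classpath as defined here.

import java.sql.Timestamp

import org.apache.spark.sql.catalyst.expressions.{GenericMutableRow, JoinedRow}

// Hypothetical demo object, not part of the patch.
object TypedAccessorDemo {
  def main(args: Array[String]): Unit = {
    // A mutable row with three slots: decimal, timestamp, binary.
    val row = new GenericMutableRow(3)
    row.setDecimal(0, BigDecimal("3.14"))
    row.setTimestamp(1, new Timestamp(1000L))
    row.setBinary(2, Array[Byte](1, 2, 3))

    // The typed getters return the stored values with their declared types.
    assert(row.getDecimal(0) == BigDecimal("3.14"))
    assert(row.getTimestamp(1).getTime == 1000L)
    assert(row.getBinary(2).sameElements(Array[Byte](1, 2, 3)))

    // JoinedRow delegates: i < row1.size reads the left row, otherwise the
    // right row at i - row1.size, so index 4 here is slot 1 of the right row.
    val joined = new JoinedRow(row, row)
    assert(joined.getTimestamp(4).getTime == 1000L)
  }
}

The JoinedRow assertions exercise the same delegation pattern the patch adds to JoinedRow through JoinedRow5; note that EmptyRow, by contrast, throws UnsupportedOperationException for these accessors, consistent with its existing getters.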