Commit 3644ffa

Add 3 missing types for Row API
Parent: 6a37ed8

3 files changed (+92, -2 lines)


sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala

Lines changed: 46 additions & 0 deletions
```diff
@@ -17,6 +17,7 @@
 
 package org.apache.spark.sql.catalyst.expressions
 
+import java.sql.Timestamp
 
 /**
  * A [[Projection]] that is calculated by calling the `eval` of each of the specified expressions.
@@ -137,6 +138,15 @@ class JoinedRow extends Row {
   def getString(i: Int): String =
     if (i < row1.size) row1.getString(i) else row2.getString(i - row1.size)
 
+  def getDecimal(i: Int): BigDecimal =
+    if (i < row1.size) row1.getDecimal(i) else row2.getDecimal(i - row1.size)
+
+  def getTimestamp(i: Int): Timestamp =
+    if (i < row1.size) row1.getTimestamp(i) else row2.getTimestamp(i - row1.size)
+
+  def getBinary(i: Int): Array[Byte] =
+    if (i < row1.size) row1.getBinary(i) else row2.getBinary(i - row1.size)
+
   def copy() = {
     val totalSize = row1.size + row2.size
     val copiedValues = new Array[Any](totalSize)
@@ -226,6 +236,15 @@ class JoinedRow2 extends Row {
   def getString(i: Int): String =
     if (i < row1.size) row1.getString(i) else row2.getString(i - row1.size)
 
+  def getDecimal(i: Int): BigDecimal =
+    if (i < row1.size) row1.getDecimal(i) else row2.getDecimal(i - row1.size)
+
+  def getTimestamp(i: Int): Timestamp =
+    if (i < row1.size) row1.getTimestamp(i) else row2.getTimestamp(i - row1.size)
+
+  def getBinary(i: Int): Array[Byte] =
+    if (i < row1.size) row1.getBinary(i) else row2.getBinary(i - row1.size)
+
   def copy() = {
     val totalSize = row1.size + row2.size
     val copiedValues = new Array[Any](totalSize)
@@ -309,6 +328,15 @@ class JoinedRow3 extends Row {
   def getString(i: Int): String =
     if (i < row1.size) row1.getString(i) else row2.getString(i - row1.size)
 
+  def getDecimal(i: Int): BigDecimal =
+    if (i < row1.size) row1.getDecimal(i) else row2.getDecimal(i - row1.size)
+
+  def getTimestamp(i: Int): Timestamp =
+    if (i < row1.size) row1.getTimestamp(i) else row2.getTimestamp(i - row1.size)
+
+  def getBinary(i: Int): Array[Byte] =
+    if (i < row1.size) row1.getBinary(i) else row2.getBinary(i - row1.size)
+
   def copy() = {
     val totalSize = row1.size + row2.size
     val copiedValues = new Array[Any](totalSize)
@@ -392,6 +420,15 @@ class JoinedRow4 extends Row {
   def getString(i: Int): String =
     if (i < row1.size) row1.getString(i) else row2.getString(i - row1.size)
 
+  def getDecimal(i: Int): BigDecimal =
+    if (i < row1.size) row1.getDecimal(i) else row2.getDecimal(i - row1.size)
+
+  def getTimestamp(i: Int): Timestamp =
+    if (i < row1.size) row1.getTimestamp(i) else row2.getTimestamp(i - row1.size)
+
+  def getBinary(i: Int): Array[Byte] =
+    if (i < row1.size) row1.getBinary(i) else row2.getBinary(i - row1.size)
+
   def copy() = {
     val totalSize = row1.size + row2.size
     val copiedValues = new Array[Any](totalSize)
@@ -475,6 +512,15 @@ class JoinedRow5 extends Row {
   def getString(i: Int): String =
     if (i < row1.size) row1.getString(i) else row2.getString(i - row1.size)
 
+  def getDecimal(i: Int): BigDecimal =
+    if (i < row1.size) row1.getDecimal(i) else row2.getDecimal(i - row1.size)
+
+  def getTimestamp(i: Int): Timestamp =
+    if (i < row1.size) row1.getTimestamp(i) else row2.getTimestamp(i - row1.size)
+
+  def getBinary(i: Int): Array[Byte] =
+    if (i < row1.size) row1.getBinary(i) else row2.getBinary(i - row1.size)
+
   def copy() = {
     val totalSize = row1.size + row2.size
     val copiedValues = new Array[Any](totalSize)
```
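For context, the new JoinedRow accessors reuse the index-splitting rule of the existing getters: an ordinal below `row1.size` is read from the left row, anything else is shifted by `row1.size` and read from the right row. A minimal, self-contained sketch of that dispatch with toy classes (`ToyRow` and `ToyJoinedRow` are invented here for illustration, not the Catalyst types):

```scala
import java.sql.Timestamp

// Toy stand-ins that only model the ordinal arithmetic used by the
// JoinedRow.getDecimal / getTimestamp / getBinary methods added above.
final case class ToyRow(values: IndexedSeq[Any]) {
  def size: Int = values.size
  def getTimestamp(i: Int): Timestamp = values(i).asInstanceOf[Timestamp]
}

final case class ToyJoinedRow(row1: ToyRow, row2: ToyRow) {
  // Same rule as the diff: ordinals below row1.size hit the left row,
  // everything else is re-based into the right row.
  def getTimestamp(i: Int): Timestamp =
    if (i < row1.size) row1.getTimestamp(i) else row2.getTimestamp(i - row1.size)
}

object JoinedRowSketch extends App {
  val left   = ToyRow(Vector(new Timestamp(0L), new Timestamp(1000L)))
  val right  = ToyRow(Vector(new Timestamp(2000L)))
  val joined = ToyJoinedRow(left, right)
  // Ordinal 2 falls past the left row, so it resolves to right.getTimestamp(0).
  println(joined.getTimestamp(2))
}
```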

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Row.scala

Lines changed: 30 additions & 0 deletions
```diff
@@ -17,6 +17,8 @@
 
 package org.apache.spark.sql.catalyst.expressions
 
+import java.sql.Timestamp
+
 import org.apache.spark.sql.catalyst.types.NativeType
 
 object Row {
@@ -64,6 +66,9 @@ trait Row extends Seq[Any] with Serializable {
   def getShort(i: Int): Short
   def getByte(i: Int): Byte
   def getString(i: Int): String
+  def getDecimal(i: Int): BigDecimal
+  def getTimestamp(i: Int): Timestamp
+  def getBinary(i: Int): Array[Byte]
 
   override def toString() =
     s"[${this.mkString(",")}]"
@@ -98,6 +103,9 @@ trait MutableRow extends Row {
   def setByte(ordinal: Int, value: Byte)
   def setFloat(ordinal: Int, value: Float)
   def setString(ordinal: Int, value: String)
+  def setDecimal(ordinal: Int, value: BigDecimal)
+  def setTimestamp(ordinal: Int, value: Timestamp)
+  def setBinary(ordinal: Int, value: Array[Byte])
 }
 
 /**
@@ -118,6 +126,9 @@ object EmptyRow extends Row {
   def getShort(i: Int): Short = throw new UnsupportedOperationException
   def getByte(i: Int): Byte = throw new UnsupportedOperationException
   def getString(i: Int): String = throw new UnsupportedOperationException
+  def getDecimal(i: Int): BigDecimal = throw new UnsupportedOperationException
+  def getTimestamp(i: Int): Timestamp = throw new UnsupportedOperationException
+  def getBinary(i: Int): Array[Byte] = throw new UnsupportedOperationException
 
   def copy() = this
 }
@@ -181,6 +192,21 @@ class GenericRow(protected[sql] val values: Array[Any]) extends Row {
     values(i).asInstanceOf[String]
   }
 
+  def getDecimal(i: Int): BigDecimal = {
+    if (values(i) == null) sys.error("Failed to check null bit for primitive Decimal value.")
+    values(i).asInstanceOf[BigDecimal]
+  }
+
+  def getTimestamp(i: Int): Timestamp = {
+    if (values(i) == null) sys.error("Failed to check null bit for primitive Timestamp value.")
+    values(i).asInstanceOf[Timestamp]
+  }
+
+  def getBinary(i: Int): Array[Byte] = {
+    if (values(i) == null) sys.error("Failed to check null bit for primitive Binary value.")
+    values(i).asInstanceOf[Array[Byte]]
+  }
+
   // Custom hashCode function that matches the efficient code generated version.
   override def hashCode(): Int = {
     var result: Int = 37
@@ -201,6 +227,7 @@ class GenericRow(protected[sql] val values: Array[Any]) extends Row {
       case d: Double =>
         val b = java.lang.Double.doubleToLongBits(d)
         (b ^ (b >>> 32)).toInt
+      case b: Array[Byte] => 123 // TODO need to figure out how to compute the hashcode
      case other => other.hashCode()
    }
  }
@@ -224,6 +251,9 @@ class GenericMutableRow(size: Int) extends GenericRow(size) with MutableRow {
   override def setInt(ordinal: Int,value: Int): Unit = { values(ordinal) = value }
   override def setLong(ordinal: Int,value: Long): Unit = { values(ordinal) = value }
   override def setString(ordinal: Int,value: String): Unit = { values(ordinal) = value }
+  override def setDecimal(ordinal: Int, value: BigDecimal): Unit = { values(ordinal) = value }
+  override def setTimestamp(ordinal: Int, value: Timestamp): Unit = { values(ordinal) = value }
+  override def setBinary(ordinal: Int, value: Array[Byte]): Unit = { values(ordinal) = value }
 
   override def setNullAt(i: Int): Unit = { values(i) = null }
 
```
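With the trait methods above in place, a row holding a decimal, a timestamp, and a binary column can be written and read back through typed accessors instead of `apply(i).asInstanceOf[...]`. A small usage sketch, assuming the catalyst module at this commit is on the classpath (the three-column layout is made up for illustration):

```scala
import java.sql.Timestamp

import org.apache.spark.sql.catalyst.expressions.GenericMutableRow

object RowAccessorSketch extends App {
  // Three made-up columns: decimal, timestamp, binary.
  val row = new GenericMutableRow(3)
  row.setDecimal(0, BigDecimal("12.34"))
  row.setTimestamp(1, new Timestamp(System.currentTimeMillis()))
  row.setBinary(2, Array[Byte](1, 2, 3))

  // The typed getters cast the stored value back; a null slot triggers
  // the "Failed to check null bit" error seen in GenericRow above.
  val d: BigDecimal  = row.getDecimal(0)
  val t: Timestamp   = row.getTimestamp(1)
  val b: Array[Byte] = row.getBinary(2)
  println(s"$d / $t / ${b.length} bytes")
}
```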

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificRow.scala

Lines changed: 16 additions & 2 deletions
```diff
@@ -17,6 +17,8 @@
 
 package org.apache.spark.sql.catalyst.expressions
 
+import java.sql.Timestamp
+
 import org.apache.spark.sql.catalyst.types._
 
 /**
@@ -231,9 +233,9 @@ final class SpecificMutableRow(val values: Array[MutableValue]) extends MutableRow {
 
   override def iterator: Iterator[Any] = values.map(_.boxed).iterator
 
-  def setString(ordinal: Int, value: String) = update(ordinal, value)
+  override def setString(ordinal: Int, value: String) = update(ordinal, value)
 
-  def getString(ordinal: Int) = apply(ordinal).asInstanceOf[String]
+  override def getString(ordinal: Int) = apply(ordinal).asInstanceOf[String]
 
   override def setInt(ordinal: Int, value: Int): Unit = {
     val currentValue = values(ordinal).asInstanceOf[MutableInt]
@@ -304,4 +306,16 @@ final class SpecificMutableRow(val values: Array[MutableValue]) extends MutableRow {
   override def getByte(i: Int): Byte = {
     values(i).asInstanceOf[MutableByte].value
   }
+
+  override def setDecimal(ordinal: Int, value: BigDecimal): Unit = update(ordinal, value)
+
+  override def getDecimal(i: Int): BigDecimal = apply(i).asInstanceOf[BigDecimal]
+
+  override def setTimestamp(ordinal: Int, value: Timestamp): Unit = update(ordinal, value)
+
+  override def getTimestamp(i: Int): Timestamp = apply(i).asInstanceOf[Timestamp]
+
+  override def setBinary(ordinal: Int, value: Array[Byte]): Unit = update(ordinal, value)
+
+  override def getBinary(i: Int): Array[Byte] = apply(i).asInstanceOf[Array[Byte]]
 }
```
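One detail worth noting in the SpecificMutableRow hunk: the primitive accessors (setInt, getByte, ...) go through dedicated MutableValue slots, while the new Decimal/Timestamp/Binary methods simply delegate to the generic update/apply pair, presumably storing the value boxed. A toy sketch of that split follows; all names below are invented, and only the delegation shape mirrors the diff:

```scala
import java.sql.Timestamp

// Toy mirror of the pattern above: primitives get a specialised mutable slot,
// object-typed columns go through the generic update/apply path.
sealed trait ToyMutableValue { def boxed: Any }
final class ToyMutableInt extends ToyMutableValue { var value: Int = 0; def boxed: Any = value }
final class ToyMutableAny extends ToyMutableValue { var value: Any = null; def boxed: Any = value }

final class ToySpecificRow(val values: Array[ToyMutableValue]) {
  def update(ordinal: Int, v: Any): Unit =
    values(ordinal).asInstanceOf[ToyMutableAny].value = v
  def apply(ordinal: Int): Any = values(ordinal).boxed

  // Specialised path: the Int never gets boxed.
  def setInt(ordinal: Int, v: Int): Unit =
    values(ordinal).asInstanceOf[ToyMutableInt].value = v
  def getInt(ordinal: Int): Int =
    values(ordinal).asInstanceOf[ToyMutableInt].value

  // Same shape as the new methods in the diff: delegate to update/apply.
  def setTimestamp(ordinal: Int, v: Timestamp): Unit = update(ordinal, v)
  def getTimestamp(ordinal: Int): Timestamp = apply(ordinal).asInstanceOf[Timestamp]
}
```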
