Skip to content

Commit 0c0e09f

Browse files
adrian-wangmarmbrus
authored andcommitted
[SPARK-3412][SQL]add missing row api
chenghao-intel assigned this to me, check PR #2284 for previous discussion Author: Daoyuan Wang <[email protected]> Closes #2529 from adrian-wang/rowapi and squashes the following commits: c6594b2 [Daoyuan Wang] using boxed 7b7e6e3 [Daoyuan Wang] update pattern match 7a39456 [Daoyuan Wang] rename file and refresh getAs[T] 4c18c29 [Daoyuan Wang] remove setAs[T] and null judge 1614493 [Daoyuan Wang] add missing row api
1 parent 1c7f0ab commit 0c0e09f

File tree

3 files changed

+32
-11
lines changed

3 files changed

+32
-11
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,9 @@ class JoinedRow extends Row {
137137
def getString(i: Int): String =
138138
if (i < row1.size) row1.getString(i) else row2.getString(i - row1.size)
139139

140+
override def getAs[T](i: Int): T =
141+
if (i < row1.size) row1.getAs[T](i) else row2.getAs[T](i - row1.size)
142+
140143
def copy() = {
141144
val totalSize = row1.size + row2.size
142145
val copiedValues = new Array[Any](totalSize)
@@ -226,6 +229,9 @@ class JoinedRow2 extends Row {
226229
def getString(i: Int): String =
227230
if (i < row1.size) row1.getString(i) else row2.getString(i - row1.size)
228231

232+
override def getAs[T](i: Int): T =
233+
if (i < row1.size) row1.getAs[T](i) else row2.getAs[T](i - row1.size)
234+
229235
def copy() = {
230236
val totalSize = row1.size + row2.size
231237
val copiedValues = new Array[Any](totalSize)
@@ -309,6 +315,9 @@ class JoinedRow3 extends Row {
309315
def getString(i: Int): String =
310316
if (i < row1.size) row1.getString(i) else row2.getString(i - row1.size)
311317

318+
override def getAs[T](i: Int): T =
319+
if (i < row1.size) row1.getAs[T](i) else row2.getAs[T](i - row1.size)
320+
312321
def copy() = {
313322
val totalSize = row1.size + row2.size
314323
val copiedValues = new Array[Any](totalSize)
@@ -392,6 +401,9 @@ class JoinedRow4 extends Row {
392401
def getString(i: Int): String =
393402
if (i < row1.size) row1.getString(i) else row2.getString(i - row1.size)
394403

404+
override def getAs[T](i: Int): T =
405+
if (i < row1.size) row1.getAs[T](i) else row2.getAs[T](i - row1.size)
406+
395407
def copy() = {
396408
val totalSize = row1.size + row2.size
397409
val copiedValues = new Array[Any](totalSize)
@@ -475,6 +487,9 @@ class JoinedRow5 extends Row {
475487
def getString(i: Int): String =
476488
if (i < row1.size) row1.getString(i) else row2.getString(i - row1.size)
477489

490+
override def getAs[T](i: Int): T =
491+
if (i < row1.size) row1.getAs[T](i) else row2.getAs[T](i - row1.size)
492+
478493
def copy() = {
479494
val totalSize = row1.size + row2.size
480495
val copiedValues = new Array[Any](totalSize)

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Row.scala

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ trait Row extends Seq[Any] with Serializable {
6464
def getShort(i: Int): Short
6565
def getByte(i: Int): Byte
6666
def getString(i: Int): String
67+
def getAs[T](i: Int): T = apply(i).asInstanceOf[T]
6768

6869
override def toString() =
6970
s"[${this.mkString(",")}]"
@@ -118,6 +119,7 @@ object EmptyRow extends Row {
118119
def getShort(i: Int): Short = throw new UnsupportedOperationException
119120
def getByte(i: Int): Byte = throw new UnsupportedOperationException
120121
def getString(i: Int): String = throw new UnsupportedOperationException
122+
override def getAs[T](i: Int): T = throw new UnsupportedOperationException
121123

122124
def copy() = this
123125
}
@@ -217,19 +219,19 @@ class GenericMutableRow(size: Int) extends GenericRow(size) with MutableRow {
217219
/** No-arg constructor for serialization. */
218220
def this() = this(0)
219221

220-
override def setBoolean(ordinal: Int,value: Boolean): Unit = { values(ordinal) = value }
221-
override def setByte(ordinal: Int,value: Byte): Unit = { values(ordinal) = value }
222-
override def setDouble(ordinal: Int,value: Double): Unit = { values(ordinal) = value }
223-
override def setFloat(ordinal: Int,value: Float): Unit = { values(ordinal) = value }
224-
override def setInt(ordinal: Int,value: Int): Unit = { values(ordinal) = value }
225-
override def setLong(ordinal: Int,value: Long): Unit = { values(ordinal) = value }
226-
override def setString(ordinal: Int,value: String): Unit = { values(ordinal) = value }
222+
override def setBoolean(ordinal: Int, value: Boolean): Unit = { values(ordinal) = value }
223+
override def setByte(ordinal: Int, value: Byte): Unit = { values(ordinal) = value }
224+
override def setDouble(ordinal: Int, value: Double): Unit = { values(ordinal) = value }
225+
override def setFloat(ordinal: Int, value: Float): Unit = { values(ordinal) = value }
226+
override def setInt(ordinal: Int, value: Int): Unit = { values(ordinal) = value }
227+
override def setLong(ordinal: Int, value: Long): Unit = { values(ordinal) = value }
228+
override def setString(ordinal: Int, value: String): Unit = { values(ordinal) = value }
227229

228230
override def setNullAt(i: Int): Unit = { values(i) = null }
229231

230-
override def setShort(ordinal: Int,value: Short): Unit = { values(ordinal) = value }
232+
override def setShort(ordinal: Int, value: Short): Unit = { values(ordinal) = value }
231233

232-
override def update(ordinal: Int,value: Any): Unit = { values(ordinal) = value }
234+
override def update(ordinal: Int, value: Any): Unit = { values(ordinal) = value }
233235

234236
override def copy() = new GenericRow(values.clone())
235237
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificRow.scala renamed to sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -233,9 +233,9 @@ final class SpecificMutableRow(val values: Array[MutableValue]) extends MutableR
233233

234234
override def iterator: Iterator[Any] = values.map(_.boxed).iterator
235235

236-
def setString(ordinal: Int, value: String) = update(ordinal, value)
236+
override def setString(ordinal: Int, value: String) = update(ordinal, value)
237237

238-
def getString(ordinal: Int) = apply(ordinal).asInstanceOf[String]
238+
override def getString(ordinal: Int) = apply(ordinal).asInstanceOf[String]
239239

240240
override def setInt(ordinal: Int, value: Int): Unit = {
241241
val currentValue = values(ordinal).asInstanceOf[MutableInt]
@@ -306,4 +306,8 @@ final class SpecificMutableRow(val values: Array[MutableValue]) extends MutableR
306306
override def getByte(i: Int): Byte = {
307307
values(i).asInstanceOf[MutableByte].value
308308
}
309+
310+
override def getAs[T](i: Int): T = {
311+
values(i).boxed.asInstanceOf[T]
312+
}
309313
}

0 commit comments

Comments
 (0)