Skip to content

Commit 1bc36cc

Browse files
committed
Refactor UnsafeRowConverter to avoid unnecessary boxing.
We now pass the source row into the column writers' `write` and `getSize` methods, allowing each writer to use type-specific accessors to extract column values.
1 parent 017b2dc commit 1bc36cc

File tree

2 files changed

+38
-74
lines changed

2 files changed

+38
-74
lines changed

sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,9 @@ public final class UnsafeRow implements MutableRow {
5454
private Object baseObject;
5555
private long baseOffset;
5656

57+
Object getBaseObject() { return baseObject; }
58+
long getBaseOffset() { return baseOffset; }
59+
5760
/** The number of fields in this row, used for calculating the bitset width (and in assertions) */
5861
private int numFields;
5962

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverter.scala

Lines changed: 35 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -36,8 +36,8 @@ class UnsafeRowConverter(fieldTypes: Array[DataType]) {
3636
private[this] val unsafeRow = new UnsafeRow()
3737

3838
/** Functions for encoding each column */
39-
private[this] val writers: Array[UnsafeColumnWriter[Any]] = {
40-
fieldTypes.map(t => UnsafeColumnWriter.forType(t).asInstanceOf[UnsafeColumnWriter[Any]])
39+
private[this] val writers: Array[UnsafeColumnWriter] = {
40+
fieldTypes.map(t => UnsafeColumnWriter.forType(t))
4141
}
4242

4343
/** The size, in bytes, of the fixed-length portion of the row, including the null bitmap */
@@ -52,7 +52,7 @@ class UnsafeRowConverter(fieldTypes: Array[DataType]) {
5252
var variableLengthFieldSize: Int = 0
5353
while (fieldNumber < writers.length) {
5454
if (!row.isNullAt(fieldNumber)) {
55-
variableLengthFieldSize += writers(fieldNumber).getSize(row(fieldNumber))
55+
variableLengthFieldSize += writers(fieldNumber).getSize(row, fieldNumber)
5656
}
5757
fieldNumber += 1
5858
}
@@ -75,13 +75,7 @@ class UnsafeRowConverter(fieldTypes: Array[DataType]) {
7575
if (row.isNullAt(fieldNumber)) {
7676
unsafeRow.setNullAt(fieldNumber)
7777
} else {
78-
appendCursor += writers(fieldNumber).write(
79-
row(fieldNumber),
80-
fieldNumber,
81-
unsafeRow,
82-
baseObject,
83-
baseOffset,
84-
appendCursor)
78+
appendCursor += writers(fieldNumber).write(row, unsafeRow, fieldNumber, appendCursor)
8579
}
8680
fieldNumber += 1
8781
}
@@ -93,36 +87,28 @@ class UnsafeRowConverter(fieldTypes: Array[DataType]) {
9387
/**
9488
* Function for writing a column into an UnsafeRow.
9589
*/
96-
private abstract class UnsafeColumnWriter[T] {
90+
private abstract class UnsafeColumnWriter {
9791
/**
9892
* Write a value into an UnsafeRow.
9993
*
100-
* @param value the value to write
101-
* @param columnNumber what column to write it to
102-
* @param row a pointer to the unsafe row
103-
* @param baseObject the base object of the target row's address
104-
* @param baseOffset the base offset of the target row's address
94+
* @param source the row being converted
95+
* @param target a pointer to the converted unsafe row
96+
* @param column the column to write
10597
* @param appendCursor the offset from the start of the unsafe row to the end of the row;
10698
* used for calculating where variable-length data should be written
10799
* @return the number of variable-length bytes written
108100
*/
109-
def write(
110-
value: T,
111-
columnNumber: Int,
112-
row: UnsafeRow,
113-
baseObject: Object,
114-
baseOffset: Long,
115-
appendCursor: Int): Int
101+
def write(source: Row, target: UnsafeRow, column: Int, appendCursor: Int): Int
116102

117103
/**
118104
* Return the number of bytes that are needed to write this variable-length value.
119105
*/
120-
def getSize(value: T): Int
106+
def getSize(source: Row, column: Int): Int
121107
}
122108

123109
private object UnsafeColumnWriter {
124110

125-
def forType(dataType: DataType): UnsafeColumnWriter[_] = {
111+
def forType(dataType: DataType): UnsafeColumnWriter = {
126112
dataType match {
127113
case IntegerType => IntUnsafeColumnWriter
128114
case LongType => LongUnsafeColumnWriter
@@ -143,74 +129,49 @@ private object FloatUnsafeColumnWriter extends FloatUnsafeColumnWriter
143129
private object DoubleUnsafeColumnWriter extends DoubleUnsafeColumnWriter
144130
private object StringUnsafeColumnWriter extends StringUnsafeColumnWriter
145131

146-
private abstract class PrimitiveUnsafeColumnWriter[T] extends UnsafeColumnWriter[T] {
147-
def getSize(value: T): Int = 0
132+
private abstract class PrimitiveUnsafeColumnWriter extends UnsafeColumnWriter {
133+
// Primitives don't write to the variable-length region:
134+
def getSize(sourceRow: Row, column: Int): Int = 0
148135
}
149136

150-
private class IntUnsafeColumnWriter private() extends PrimitiveUnsafeColumnWriter[Int] {
151-
override def write(
152-
value: Int,
153-
columnNumber: Int,
154-
row: UnsafeRow,
155-
baseObject: Object,
156-
baseOffset: Long,
157-
appendCursor: Int): Int = {
158-
row.setInt(columnNumber, value)
137+
private class IntUnsafeColumnWriter private() extends PrimitiveUnsafeColumnWriter {
138+
override def write(source: Row, target: UnsafeRow, column: Int, appendCursor: Int): Int = {
139+
target.setInt(column, source.getInt(column))
159140
0
160141
}
161142
}
162143

163-
private class LongUnsafeColumnWriter private() extends PrimitiveUnsafeColumnWriter[Long] {
164-
override def write(
165-
value: Long,
166-
columnNumber: Int,
167-
row: UnsafeRow,
168-
baseObject: Object,
169-
baseOffset: Long,
170-
appendCursor: Int): Int = {
171-
row.setLong(columnNumber, value)
144+
private class LongUnsafeColumnWriter private() extends PrimitiveUnsafeColumnWriter {
145+
override def write(source: Row, target: UnsafeRow, column: Int, appendCursor: Int): Int = {
146+
target.setLong(column, source.getLong(column))
172147
0
173148
}
174149
}
175150

176-
private class FloatUnsafeColumnWriter private() extends PrimitiveUnsafeColumnWriter[Float] {
177-
override def write(
178-
value: Float,
179-
columnNumber: Int,
180-
row: UnsafeRow,
181-
baseObject: Object,
182-
baseOffset: Long,
183-
appendCursor: Int): Int = {
184-
row.setFloat(columnNumber, value)
151+
private class FloatUnsafeColumnWriter private() extends PrimitiveUnsafeColumnWriter {
152+
override def write(source: Row, target: UnsafeRow, column: Int, appendCursor: Int): Int = {
153+
target.setFloat(column, source.getFloat(column))
185154
0
186155
}
187156
}
188157

189-
private class DoubleUnsafeColumnWriter private() extends PrimitiveUnsafeColumnWriter[Double] {
190-
override def write(
191-
value: Double,
192-
columnNumber: Int,
193-
row: UnsafeRow,
194-
baseObject: Object,
195-
baseOffset: Long,
196-
appendCursor: Int): Int = {
197-
row.setDouble(columnNumber, value)
158+
private class DoubleUnsafeColumnWriter private() extends PrimitiveUnsafeColumnWriter {
159+
override def write(source: Row, target: UnsafeRow, column: Int, appendCursor: Int): Int = {
160+
target.setDouble(column, source.getDouble(column))
198161
0
199162
}
200163
}
201164

202-
private class StringUnsafeColumnWriter private() extends UnsafeColumnWriter[UTF8String] {
203-
def getSize(value: UTF8String): Int = {
204-
8 + ByteArrayMethods.roundNumberOfBytesToNearestWord(value.getBytes.length)
165+
private class StringUnsafeColumnWriter private() extends UnsafeColumnWriter {
166+
def getSize(source: Row, column: Int): Int = {
167+
val numBytes = source.get(column).asInstanceOf[UTF8String].getBytes.length
168+
8 + ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes)
205169
}
206170

207-
override def write(
208-
value: UTF8String,
209-
columnNumber: Int,
210-
row: UnsafeRow,
211-
baseObject: Object,
212-
baseOffset: Long,
213-
appendCursor: Int): Int = {
171+
override def write(source: Row, target: UnsafeRow, column: Int, appendCursor: Int): Int = {
172+
val value = source.get(column).asInstanceOf[UTF8String]
173+
val baseObject = target.getBaseObject
174+
val baseOffset = target.getBaseOffset
214175
val numBytes = value.getBytes.length
215176
PlatformDependent.UNSAFE.putLong(baseObject, baseOffset + appendCursor, numBytes)
216177
PlatformDependent.copyMemory(
@@ -220,7 +181,7 @@ private class StringUnsafeColumnWriter private() extends UnsafeColumnWriter[UTF8
220181
baseOffset + appendCursor + 8,
221182
numBytes
222183
)
223-
row.setLong(columnNumber, appendCursor)
184+
target.setLong(column, appendCursor)
224185
8 + ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes)
225186
}
226187
}

0 commit comments

Comments
 (0)