Skip to content

Commit eeee512

Browse files
committed
Add converters for Null, Boolean, Byte, and Short columns.
1 parent 81f34f8 commit eeee512

File tree

3 files changed

+78
-21
lines changed

3 files changed

+78
-21
lines changed

sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java

Lines changed: 6 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -94,13 +94,14 @@ public static int calculateBitSetWidthInBytes(int numFields) {
9494
settableFieldTypes = Collections.unmodifiableSet(
9595
new HashSet<DataType>(
9696
Arrays.asList(new DataType[] {
97-
IntegerType,
98-
LongType,
99-
DoubleType,
97+
NullType,
10098
BooleanType,
101-
ShortType,
10299
ByteType,
103-
FloatType
100+
ShortType,
101+
IntegerType,
102+
LongType,
103+
FloatType,
104+
DoubleType
104105
})));
105106

106107
// We support get() on a superset of the types for which we support set():

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverter.scala

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,10 @@ private object UnsafeColumnWriter {
110110

111111
def forType(dataType: DataType): UnsafeColumnWriter = {
112112
dataType match {
113+
case NullType => NullUnsafeColumnWriter
114+
case BooleanType => BooleanUnsafeColumnWriter
115+
case ByteType => ByteUnsafeColumnWriter
116+
case ShortType => ShortUnsafeColumnWriter
113117
case IntegerType => IntUnsafeColumnWriter
114118
case LongType => LongUnsafeColumnWriter
115119
case FloatType => FloatUnsafeColumnWriter
@@ -123,6 +127,10 @@ private object UnsafeColumnWriter {
123127

124128
// ------------------------------------------------------------------------------------------------
125129

130+
private object NullUnsafeColumnWriter extends NullUnsafeColumnWriter
131+
private object BooleanUnsafeColumnWriter extends BooleanUnsafeColumnWriter
132+
private object ByteUnsafeColumnWriter extends ByteUnsafeColumnWriter
133+
private object ShortUnsafeColumnWriter extends ShortUnsafeColumnWriter
126134
private object IntUnsafeColumnWriter extends IntUnsafeColumnWriter
127135
private object LongUnsafeColumnWriter extends LongUnsafeColumnWriter
128136
private object FloatUnsafeColumnWriter extends FloatUnsafeColumnWriter
@@ -134,6 +142,34 @@ private abstract class PrimitiveUnsafeColumnWriter extends UnsafeColumnWriter {
134142
def getSize(sourceRow: Row, column: Int): Int = 0
135143
}
136144

145+
private class NullUnsafeColumnWriter private() extends PrimitiveUnsafeColumnWriter {
146+
override def write(source: Row, target: UnsafeRow, column: Int, appendCursor: Int): Int = {
147+
target.setNullAt(column)
148+
0
149+
}
150+
}
151+
152+
private class BooleanUnsafeColumnWriter private() extends PrimitiveUnsafeColumnWriter {
153+
override def write(source: Row, target: UnsafeRow, column: Int, appendCursor: Int): Int = {
154+
target.setBoolean(column, source.getBoolean(column))
155+
0
156+
}
157+
}
158+
159+
private class ByteUnsafeColumnWriter private() extends PrimitiveUnsafeColumnWriter {
160+
override def write(source: Row, target: UnsafeRow, column: Int, appendCursor: Int): Int = {
161+
target.setByte(column, source.getByte(column))
162+
0
163+
}
164+
}
165+
166+
private class ShortUnsafeColumnWriter private() extends PrimitiveUnsafeColumnWriter {
167+
override def write(source: Row, target: UnsafeRow, column: Int, appendCursor: Int): Int = {
168+
target.setShort(column, source.getShort(column))
169+
0
170+
}
171+
}
172+
137173
private class IntUnsafeColumnWriter private() extends PrimitiveUnsafeColumnWriter {
138174
override def write(source: Row, target: UnsafeRow, column: Int, appendCursor: Int): Int = {
139175
target.setInt(column, source.getInt(column))

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala

Lines changed: 36 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -74,12 +74,20 @@ class UnsafeRowConverterSuite extends FunSuite with Matchers {
7474
}
7575

7676
test("null handling") {
77-
val fieldTypes: Array[DataType] = Array(IntegerType, LongType, FloatType, DoubleType)
77+
val fieldTypes: Array[DataType] = Array(
78+
NullType,
79+
BooleanType,
80+
ByteType,
81+
ShortType,
82+
IntegerType,
83+
LongType,
84+
FloatType,
85+
DoubleType)
7886
val converter = new UnsafeRowConverter(fieldTypes)
7987

8088
val rowWithAllNullColumns: Row = {
8189
val r = new SpecificMutableRow(fieldTypes)
82-
for (i <- 0 to 3) {
90+
for (i <- 0 to fieldTypes.length - 1) {
8391
r.setNullAt(i)
8492
}
8593
r
@@ -94,23 +102,30 @@ class UnsafeRowConverterSuite extends FunSuite with Matchers {
94102
val createdFromNull = new UnsafeRow()
95103
createdFromNull.pointTo(
96104
createdFromNullBuffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, null)
97-
for (i <- 0 to 3) {
105+
for (i <- 0 to fieldTypes.length - 1) {
98106
assert(createdFromNull.isNullAt(i))
99107
}
100-
createdFromNull.getInt(0) should be (0)
101-
createdFromNull.getLong(1) should be (0)
102-
assert(java.lang.Float.isNaN(createdFromNull.getFloat(2)))
103-
assert(java.lang.Double.isNaN(createdFromNull.getFloat(3)))
108+
createdFromNull.getBoolean(1) should be (false)
109+
createdFromNull.getByte(2) should be (0)
110+
createdFromNull.getShort(3) should be (0)
111+
createdFromNull.getInt(4) should be (0)
112+
createdFromNull.getLong(5) should be (0)
113+
assert(java.lang.Float.isNaN(createdFromNull.getFloat(6)))
114+
// Column 7 is DoubleType, so read it through getDouble (not getFloat) when checking for NaN.
assert(java.lang.Double.isNaN(createdFromNull.getDouble(7)))
104115

105116
// If we have an UnsafeRow with columns that are initially non-null and we null out those
106117
// columns, then the serialized row representation should be identical to what we would get by
107118
// creating an entirely null row via the converter
108119
val rowWithNoNullColumns: Row = {
109120
val r = new SpecificMutableRow(fieldTypes)
110-
r.setInt(0, 100)
111-
r.setLong(1, 200)
112-
r.setFloat(2, 300)
113-
r.setDouble(3, 400)
121+
r.setNullAt(0)
122+
r.setBoolean(1, false)
123+
r.setByte(2, 20)
124+
r.setShort(3, 30)
125+
r.setInt(4, 400)
126+
r.setLong(5, 500)
127+
r.setFloat(6, 600)
128+
r.setDouble(7, 700)
114129
r
115130
}
116131
val setToNullAfterCreationBuffer: Array[Long] = new Array[Long](sizeRequired / 8)
@@ -119,12 +134,17 @@ class UnsafeRowConverterSuite extends FunSuite with Matchers {
119134
val setToNullAfterCreation = new UnsafeRow()
120135
setToNullAfterCreation.pointTo(
121136
setToNullAfterCreationBuffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, null)
122-
setToNullAfterCreation.getInt(0) should be (rowWithNoNullColumns.getInt(0))
123-
setToNullAfterCreation.getLong(1) should be (rowWithNoNullColumns.getLong(1))
124-
setToNullAfterCreation.getFloat(2) should be (rowWithNoNullColumns.getFloat(2))
125-
setToNullAfterCreation.getDouble(3) should be (rowWithNoNullColumns.getDouble(3))
126137

127-
for (i <- 0 to 3) {
138+
setToNullAfterCreation.isNullAt(0) should be (rowWithNoNullColumns.isNullAt(0))
139+
setToNullAfterCreation.getBoolean(1) should be (rowWithNoNullColumns.getBoolean(1))
140+
setToNullAfterCreation.getByte(2) should be (rowWithNoNullColumns.getByte(2))
141+
setToNullAfterCreation.getShort(3) should be (rowWithNoNullColumns.getShort(3))
142+
setToNullAfterCreation.getInt(4) should be (rowWithNoNullColumns.getInt(4))
143+
setToNullAfterCreation.getLong(5) should be (rowWithNoNullColumns.getLong(5))
144+
setToNullAfterCreation.getFloat(6) should be (rowWithNoNullColumns.getFloat(6))
145+
setToNullAfterCreation.getDouble(7) should be (rowWithNoNullColumns.getDouble(7))
146+
147+
for (i <- 0 to fieldTypes.length - 1) {
128148
setToNullAfterCreation.setNullAt(i)
129149
}
130150
assert(Arrays.equals(createdFromNullBuffer, setToNullAfterCreationBuffer))

0 commit comments

Comments
 (0)