Skip to content

Commit 3875e1f

Browse files
committed
eliminate nullcheck code if all of the elements do not have null
add unit tests
1 parent 688b6ef commit 3875e1f

File tree

2 files changed

+62
-16
lines changed

2 files changed

+62
-16
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala

Lines changed: 28 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -117,22 +117,22 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
117117
$rowWriter.setOffsetAndSize($index, $tmpCursor, $bufferHolder.cursor - $tmpCursor);
118118
"""
119119

120-
case a @ ArrayType(et, _) =>
120+
case a @ ArrayType(et, cn) =>
121121
s"""
122122
// Remember the current cursor so that we can calculate how many bytes are
123123
// written later.
124124
final int $tmpCursor = $bufferHolder.cursor;
125-
${writeArrayToBuffer(ctx, input.value, et, bufferHolder)}
125+
${writeArrayToBuffer(ctx, input.value, et, cn, bufferHolder)}
126126
$rowWriter.setOffsetAndSize($index, $tmpCursor, $bufferHolder.cursor - $tmpCursor);
127127
$rowWriter.alignToWords($bufferHolder.cursor - $tmpCursor);
128128
"""
129129

130-
case m @ MapType(kt, vt, _) =>
130+
case m @ MapType(kt, vt, cn) =>
131131
s"""
132132
// Remember the current cursor so that we can calculate how many bytes are
133133
// written later.
134134
final int $tmpCursor = $bufferHolder.cursor;
135-
${writeMapToBuffer(ctx, input.value, kt, vt, bufferHolder)}
135+
${writeMapToBuffer(ctx, input.value, kt, vt, cn, bufferHolder)}
136136
$rowWriter.setOffsetAndSize($index, $tmpCursor, $bufferHolder.cursor - $tmpCursor);
137137
$rowWriter.alignToWords($bufferHolder.cursor - $tmpCursor);
138138
"""
@@ -173,6 +173,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
173173
ctx: CodegenContext,
174174
input: String,
175175
elementType: DataType,
176+
containsNull: Boolean,
176177
bufferHolder: String): String = {
177178
val arrayWriterClass = classOf[UnsafeArrayWriter].getName
178179
val arrayWriter = ctx.freshName("arrayWriter")
@@ -202,16 +203,16 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
202203
${writeStructToBuffer(ctx, element, t.map(_.dataType), bufferHolder)}
203204
"""
204205

205-
case a @ ArrayType(et, _) =>
206+
case a @ ArrayType(et, cn) =>
206207
s"""
207208
$arrayWriter.setOffset($index);
208-
${writeArrayToBuffer(ctx, element, et, bufferHolder)}
209+
${writeArrayToBuffer(ctx, element, et, cn, bufferHolder)}
209210
"""
210211

211-
case m @ MapType(kt, vt, _) =>
212+
case m @ MapType(kt, vt, cn) =>
212213
s"""
213214
$arrayWriter.setOffset($index);
214-
${writeMapToBuffer(ctx, element, kt, vt, bufferHolder)}
215+
${writeMapToBuffer(ctx, element, kt, vt, cn, bufferHolder)}
215216
"""
216217

217218
case t: DecimalType =>
@@ -222,6 +223,21 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
222223
case _ => s"$arrayWriter.write($index, $element);"
223224
}
224225

226+
val storeElement = if (containsNull) {
227+
s"""
228+
if ($input.isNullAt($index)) {
229+
$arrayWriter.setNullAt($index);
230+
} else {
231+
final $jt $element = ${ctx.getValue(input, et, index)};
232+
$writeElement
233+
}
234+
"""
235+
} else {
236+
s"""
237+
final $jt $element = ${ctx.getValue(input, et, index)};
238+
$writeElement
239+
"""
240+
}
225241
s"""
226242
if ($input instanceof UnsafeArrayData) {
227243
${writeUnsafeData(ctx, s"((UnsafeArrayData) $input)", bufferHolder)}
@@ -230,12 +246,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
230246
$arrayWriter.initialize($bufferHolder, $numElements, $fixedElementSize);
231247

232248
for (int $index = 0; $index < $numElements; $index++) {
233-
if ($input.isNullAt($index)) {
234-
$arrayWriter.setNullAt($index);
235-
} else {
236-
final $jt $element = ${ctx.getValue(input, et, index)};
237-
$writeElement
238-
}
249+
$storeElement
239250
}
240251
}
241252
"""
@@ -247,6 +258,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
247258
input: String,
248259
keyType: DataType,
249260
valueType: DataType,
261+
valueContainsNull: Boolean,
250262
bufferHolder: String): String = {
251263
val keys = ctx.freshName("keys")
252264
val values = ctx.freshName("values")
@@ -268,11 +280,11 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
268280
// Remember the current cursor so that we can write numBytes of key array later.
269281
final int $tmpCursor = $bufferHolder.cursor;
270282

271-
${writeArrayToBuffer(ctx, keys, keyType, bufferHolder)}
283+
${writeArrayToBuffer(ctx, keys, keyType, false, bufferHolder)}
272284
// Write the numBytes of key array into the first 4 bytes.
273285
Platform.putInt($bufferHolder.buffer, $tmpCursor - 4, $bufferHolder.cursor - $tmpCursor);
274286

275-
${writeArrayToBuffer(ctx, values, valueType, bufferHolder)}
287+
${writeArrayToBuffer(ctx, values, valueType, valueContainsNull, bufferHolder)}
276288
}
277289
"""
278290
}

sql/core/src/test/scala/org/apache/spark/sql/DataFrameComplexTypeSuite.scala

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,40 @@ import org.apache.spark.sql.test.SharedSQLContext
2626
class DataFrameComplexTypeSuite extends QueryTest with SharedSQLContext {
2727
import testImplicits._
2828

29+
test("primitive type on array") {
30+
val rows = sparkContext.parallelize(Seq(1, 2), 1).toDF("v").
31+
selectExpr("Array(v + 2, v + 3)").collect
32+
QueryTest.sameRows(Seq(Row(Array(3, 4)), Row(Array(4, 5))), rows.toSeq)
33+
}
34+
35+
test("primitive type and null on array") {
36+
val rows = sparkContext.parallelize(Seq(1, 2), 1).toDF("v").
37+
selectExpr("Array(v + 2, null, v + 3)").collect
38+
QueryTest.sameRows(Seq(Row(Array(3, null, 4)), Row(Array(4, null, 5))), rows.toSeq)
39+
}
40+
41+
test("array with null on array") {
42+
val rows = sparkContext.parallelize(Seq(1, 2), 1).toDF("v").
43+
selectExpr("Array(Array(v, v + 1)," +
44+
"null," +
45+
"Array(v, v - 1))").collect
46+
QueryTest.sameRows(Seq(
47+
Row(Array(Array(1, 2), null, Array(3, 4))),
48+
Row(Array(Array(2, 3), null, Array(4, 5)))), rows.toSeq)
49+
}
50+
51+
test("primitive type on map") {
52+
val rows = sparkContext.parallelize(Seq(1, 2), 1).toDF("v").
53+
selectExpr("map(v + 3, v + 4)").collect
54+
QueryTest.sameRows(Seq(Row(Map(4 -> 5)), Row(Map(5 -> 6))), rows.toSeq)
55+
}
56+
57+
test("map with null value on map") {
58+
val rows = sparkContext.parallelize(Seq(1, 2), 1).toDF("v").
59+
selectExpr("map(v, null)").collect
60+
QueryTest.sameRows(Seq(Row(Map(1 -> null)), Row(Map(2 -> null))), rows.toSeq)
61+
}
62+
2963
test("UDF on struct") {
3064
val f = udf((a: String) => a)
3165
val df = sparkContext.parallelize(Seq((1, 1))).toDF("a", "b")

0 commit comments

Comments
 (0)