Commit 1ca65b4

[SPARK-32167][SQL] Fix GetArrayStructFields to respect inner field's nullability together
Fix the nullability of `GetArrayStructFields`: it should consider both the original array's `containsNull` and the inner field's nullability.

Why are the changes needed? This fixes a correctness issue.

Does this introduce any user-facing change? Yes; see the added test.

How was this patch tested? With a new unit test and an end-to-end test.

Closes #28992 from cloud-fan/bug.

Authored-by: Wenchen Fan <[email protected]>
Signed-off-by: Dongjoon Hyun <[email protected]>
(cherry picked from commit 5d296ed)
Signed-off-by: Dongjoon Hyun <[email protected]>
1 parent 2227a16 commit 1ca65b4
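
For context, a standalone reproduction sketch adapted from the end-to-end test added below in sql/core/src/test/scala/org/apache/spark/sql/ComplexTypesSuite.scala. The object name, the local master setting, and the `show` call are illustrative additions, not part of the commit; the schema and expected answer come from the added test. The array cannot contain null structs, but the inner field `i` is nullable, so the extracted column must be an array whose elements may be null.

// Reproduction sketch (illustrative; assumes a local SparkSession can be created).
import scala.collection.JavaConverters._

import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.types.{ArrayType, IntegerType, StructType}

object Spark32167Repro {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").appName("SPARK-32167").getOrCreate()
    import spark.implicits._

    // The array itself cannot contain null structs, but the struct field "i" is nullable.
    val innerStruct = new StructType().add("i", IntegerType, nullable = true)
    val schema = new StructType().add("arr", ArrayType(innerStruct, containsNull = false))
    val df = spark.createDataFrame(List(Row(Seq(Row(1), Row(null)))).asJava, schema)

    // With the fix, the extracted column is an array with nullable elements,
    // and the expected answer is Seq(1, null).
    df.select($"arr".getField("i")).show(false)

    spark.stop()
  }
}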

4 files changed: +41 -3 lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala

Lines changed: 1 addition & 1 deletion
@@ -57,7 +57,7 @@ object ExtractValue {
       val fieldName = v.toString
       val ordinal = findField(fields, fieldName, resolver)
       GetArrayStructFields(child, fields(ordinal).copy(name = fieldName),
-        ordinal, fields.length, containsNull)
+        ordinal, fields.length, containsNull || fields(ordinal).nullable)

     case (_: ArrayType, _) => GetArrayItem(child, extraction)
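
The one-line change above is the whole fix: the elements of `arr.field` can be null if either the array may contain null structs (`containsNull`) or the extracted field itself is nullable. Below is a minimal sketch of that rule using only the public type API; the helper `extractedElementNullable` is illustrative and not part of the patch.

import org.apache.spark.sql.types._

// Nullability of the elements of `arr.<fieldName>`, per the fixed GetArrayStructFields:
// containsNull of the input array OR nullable of the extracted field.
def extractedElementNullable(arr: ArrayType, fieldName: String): Boolean = {
  val field = arr.elementType.asInstanceOf[StructType](fieldName)
  arr.containsNull || field.nullable
}

val nullableField = ArrayType(new StructType().add("a", IntegerType, nullable = true), containsNull = false)
val nullElements  = ArrayType(new StructType().add("a", IntegerType, nullable = false), containsNull = true)
val nonNullable   = ArrayType(new StructType().add("a", IntegerType, nullable = false), containsNull = false)

assert(extractedElementNullable(nullableField, "a"))  // field is nullable
assert(extractedElementNullable(nullElements, "a"))   // array may hold null structs
assert(!extractedElementNullable(nonNullable, "a"))   // neither: result stays non-nullable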

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala

Lines changed: 27 additions & 0 deletions
@@ -18,9 +18,11 @@
 package org.apache.spark.sql.catalyst.expressions

 import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.Row
 import org.apache.spark.sql.catalyst.analysis.UnresolvedExtractValue
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String

@@ -125,6 +127,31 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper {
     checkEvaluation(getArrayStructFields(nullArrayStruct, "a"), null)
   }

+  test("SPARK-32167: nullability of GetArrayStructFields") {
+    val resolver = SQLConf.get.resolver
+
+    val array1 = ArrayType(
+      new StructType().add("a", "int", nullable = true),
+      containsNull = false)
+    val data1 = Literal.create(Seq(Row(null)), array1)
+    val get1 = ExtractValue(data1, Literal("a"), resolver).asInstanceOf[GetArrayStructFields]
+    assert(get1.containsNull)
+
+    val array2 = ArrayType(
+      new StructType().add("a", "int", nullable = false),
+      containsNull = true)
+    val data2 = Literal.create(Seq(null), array2)
+    val get2 = ExtractValue(data2, Literal("a"), resolver).asInstanceOf[GetArrayStructFields]
+    assert(get2.containsNull)
+
+    val array3 = ArrayType(
+      new StructType().add("a", "int", nullable = false),
+      containsNull = false)
+    val data3 = Literal.create(Seq(Row(1)), array3)
+    val get3 = ExtractValue(data3, Literal("a"), resolver).asInstanceOf[GetArrayStructFields]
+    assert(!get3.containsNull)
+  }
+
   test("CreateArray") {
     val intSeq = Seq(5, 10, 15, 20, 25)
     val longSeq = intSeq.map(_.toLong)

sql/core/src/test/scala/org/apache/spark/sql/ComplexTypesSuite.scala

Lines changed: 11 additions & 0 deletions
@@ -17,11 +17,15 @@

 package org.apache.spark.sql

+import scala.collection.JavaConverters._
+
 import org.apache.spark.sql.catalyst.expressions.CreateNamedStruct
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.test.SharedSQLContext
+import org.apache.spark.sql.types.{ArrayType, StructType}

 class ComplexTypesSuite extends QueryTest with SharedSQLContext {
+  import testImplicits._

   override def beforeAll() {
     super.beforeAll()
@@ -106,4 +110,11 @@ class ComplexTypesSuite extends QueryTest with SharedSQLContext {
     checkAnswer(df1, Row(10, 12) :: Row(11, 13) :: Nil)
     checkNamedStruct(df.queryExecution.optimizedPlan, expectedCount = 0)
   }
+
+  test("SPARK-32167: get field from an array of struct") {
+    val innerStruct = new StructType().add("i", "int", nullable = true)
+    val schema = new StructType().add("arr", ArrayType(innerStruct, containsNull = false))
+    val df = spark.createDataFrame(List(Row(Seq(Row(1), Row(null)))).asJava, schema)
+    checkAnswer(df.select($"arr".getField("i")), Row(Seq(1, null)))
+  }
 }

sql/core/src/test/scala/org/apache/spark/sql/execution/SelectedFieldSuite.scala

Lines changed: 2 additions & 2 deletions
@@ -256,13 +256,13 @@ class SelectedFieldSuite extends SparkFunSuite with BeforeAndAfterAll {
     StructField("col3", ArrayType(StructType(
       StructField("field1", StructType(
         StructField("subfield1", IntegerType, nullable = false) :: Nil))
-      :: Nil), containsNull = false), nullable = false)
+      :: Nil), containsNull = true), nullable = false)
   }

   testSelect(arrayWithStructAndMap, "col3.field2['foo'] as foo") {
     StructField("col3", ArrayType(StructType(
       StructField("field2", MapType(StringType, IntegerType, valueContainsNull = false))
-      :: Nil), containsNull = false), nullable = false)
+      :: Nil), containsNull = true), nullable = false)
   }

   // |-- col1: string (nullable = false)
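
The flipped expectations above follow directly from the new rule: the fixture's `field1` and `field2` are evidently nullable (the `StructField` default), so extracting them from `col3`'s array now yields elements that may be null even though `col3` itself has `containsNull = false`. A quick sketch of the arithmetic, under that assumption about the fixture:

import org.apache.spark.sql.types._

// field1 is nullable by default, so the extracted array element becomes nullable
// even though col3's array itself cannot contain null structs.
val field1 = StructField("field1",
  StructType(StructField("subfield1", IntegerType, nullable = false) :: Nil))
val col3 = ArrayType(StructType(field1 :: Nil), containsNull = false)

// containsNull of col3.field1 = col3.containsNull || field1.nullable = false || true = true
assert(col3.containsNull || field1.nullable)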
