From d22707aab9df2b41e30bf381dfe8c4387614152a Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Mon, 6 Jul 2020 20:07:33 -0700 Subject: [PATCH] [SPARK-32167][SQL] Fix GetArrayStructFields to respect inner field's nullability together Fix nullability of `GetArrayStructFields`. It should consider both the original array's `containsNull` and the inner field's nullability. Fix a correctness issue. Yes. See the added test. a new UT and end-to-end test Closes #28992 from cloud-fan/bug. Authored-by: Wenchen Fan Signed-off-by: Dongjoon Hyun (cherry picked from commit 5d296ed39e3dd79ddb10c68657e773adba40a5e0) Signed-off-by: Dongjoon Hyun --- .../expressions/complexTypeExtractors.scala | 2 +- .../expressions/ComplexTypeSuite.scala | 27 +++++++++++++++++++ .../apache/spark/sql/ComplexTypesSuite.scala | 11 ++++++++ 3 files changed, 39 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala index 8994eeff92c7..a3dd983d0591 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala @@ -57,7 +57,7 @@ object ExtractValue { val fieldName = v.toString val ordinal = findField(fields, fieldName, resolver) GetArrayStructFields(child, fields(ordinal).copy(name = fieldName), - ordinal, fields.length, containsNull) + ordinal, fields.length, containsNull || fields(ordinal).nullable) case (_: ArrayType, _) => GetArrayItem(child, extraction) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala index 77aaf55480ec..c50191d5bdd4 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala @@ -18,9 +18,11 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.analysis.UnresolvedExtractValue import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -125,6 +127,31 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(getArrayStructFields(nullArrayStruct, "a"), null) } + test("SPARK-32167: nullability of GetArrayStructFields") { + val resolver = SQLConf.get.resolver + + val array1 = ArrayType( + new StructType().add("a", "int", nullable = true), + containsNull = false) + val data1 = Literal.create(Seq(Row(null)), array1) + val get1 = ExtractValue(data1, Literal("a"), resolver).asInstanceOf[GetArrayStructFields] + assert(get1.containsNull) + + val array2 = ArrayType( + new StructType().add("a", "int", nullable = false), + containsNull = true) + val data2 = Literal.create(Seq(null), array2) + val get2 = ExtractValue(data2, Literal("a"), resolver).asInstanceOf[GetArrayStructFields] + assert(get2.containsNull) + + val array3 = ArrayType( + new StructType().add("a", "int", nullable = false), + containsNull = false) + val data3 = Literal.create(Seq(Row(1)), array3) + val get3 = ExtractValue(data3, Literal("a"), resolver).asInstanceOf[GetArrayStructFields] + assert(!get3.containsNull) + } + test("CreateArray") { val intSeq = Seq(5, 10, 15, 20, 25) val longSeq = intSeq.map(_.toLong) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ComplexTypesSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ComplexTypesSuite.scala index b74fe2f90df2..30ce1c73a0c8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ComplexTypesSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ComplexTypesSuite.scala @@ -17,11 +17,15 @@ package org.apache.spark.sql +import scala.collection.JavaConverters._ + import org.apache.spark.sql.catalyst.expressions.CreateNamedStruct import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.types.{ArrayType, StructType} class ComplexTypesSuite extends QueryTest with SharedSQLContext { + import testImplicits._ override def beforeAll() { super.beforeAll() @@ -106,4 +110,11 @@ class ComplexTypesSuite extends QueryTest with SharedSQLContext { checkAnswer(df1, Row(10, 12) :: Row(11, 13) :: Nil) checkNamedStruct(df.queryExecution.optimizedPlan, expectedCount = 0) } + + test("SPARK-32167: get field from an array of struct") { + val innerStruct = new StructType().add("i", "int", nullable = true) + val schema = new StructType().add("arr", ArrayType(innerStruct, containsNull = false)) + val df = spark.createDataFrame(List(Row(Seq(Row(1), Row(null)))).asJava, schema) + checkAnswer(df.select($"arr".getField("i")), Row(Seq(1, null))) + } }