Skip to content

Commit 7bae2c1

Browse files
committed
Non-nullable null type should not coerce to nullable type
1 parent bd32400 commit 7bae2c1

File tree

5 files changed

+54
-8
lines changed

5 files changed

+54
-8
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -160,7 +160,7 @@ object TypeCoercion {
160160
}
161161
case (MapType(kt1, vt1, valueContainsNull1), MapType(kt2, vt2, valueContainsNull2)) =>
162162
findTypeFunc(kt1, kt2)
163-
.filter { kt => Cast.canCastMapKeyNullSafe(kt1, kt) && Cast.canCastMapKeyNullSafe(kt2, kt) }
163+
.filter { kt => !Cast.forceNullable(kt1, kt) && !Cast.forceNullable(kt2, kt) }
164164
.flatMap { kt =>
165165
findTypeFunc(vt1, vt2).map { vt =>
166166
MapType(kt, vt, valueContainsNull1 || valueContainsNull2 ||

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,8 @@ object Cast {
7777
resolvableNullability(fn || forceNullable(fromType, toType), tn)
7878

7979
case (MapType(fromKey, fromValue, fn), MapType(toKey, toValue, tn)) =>
80-
canCast(fromKey, toKey) && canCastMapKeyNullSafe(fromKey, toKey) &&
80+
canCast(fromKey, toKey) &&
81+
(!forceNullable(fromKey, toKey)) &&
8182
canCast(fromValue, toValue) &&
8283
resolvableNullability(fn || forceNullable(fromValue, toValue), tn)
8384

@@ -97,11 +98,6 @@ object Cast {
9798
case _ => false
9899
}
99100

100-
def canCastMapKeyNullSafe(fromType: DataType, toType: DataType): Boolean = {
101-
// If the original map key type is NullType, it's OK as the map must be empty.
102-
fromType == NullType || !forceNullable(fromType, toType)
103-
}
104-
105101
/**
106102
* Return true if we need to use the `timeZone` information casting `from` type to `to` type.
107103
* The patterns matched reflect the current implementation in the Cast node.
@@ -210,8 +206,15 @@ object Cast {
210206
case _ => false // overflow
211207
}
212208

209+
/**
210+
* Returns `true` iff converting `from` to `to` forces the resulting type to become nullable,
211+
* whether that type appears inside a map type, an array type, or an expression such as a cast.
212+
*
213+
* Note that you should take the nullability context into account.
214+
* For example, do not force a null type inside an array type to nullable when it is
215+
* declared non-nullable, since that combination can only describe an empty array.
216+
*/
213217
def forceNullable(from: DataType, to: DataType): Boolean = (from, to) match {
214-
case (NullType, _) => true
215218
case (_, _) if from == to => false
216219

217220
case (StringType, BinaryType) => false

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -497,6 +497,23 @@ class TypeCoercionSuite extends AnalysisTest {
497497
.add("null", IntegerType, nullable = false),
498498
Some(new StructType()
499499
.add("null", IntegerType, nullable = true)))
500+
501+
widenTest(
502+
ArrayType(NullType, containsNull = false),
503+
ArrayType(IntegerType, containsNull = false),
504+
Some(ArrayType(IntegerType, containsNull = false)))
505+
506+
widenTest(MapType(NullType, NullType, false),
507+
MapType(IntegerType, StringType, false),
508+
Some(MapType(IntegerType, StringType, false)))
509+
510+
widenTest(
511+
new StructType()
512+
.add("null", NullType, nullable = false),
513+
new StructType()
514+
.add("null", IntegerType, nullable = false),
515+
Some(new StructType()
516+
.add("null", IntegerType, nullable = false)))
500517
}
501518

502519
test("wider common type for decimal and array") {

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -412,6 +412,14 @@ abstract class CastSuiteBase extends SparkFunSuite with ExpressionEvalHelper {
412412
assert(ret.resolved)
413413
checkEvaluation(ret, Seq(null, true, false, null))
414414
}
415+
416+
{
417+
val array = Literal.create(Seq.empty, ArrayType(NullType, containsNull = false))
418+
val ret = cast(array, ArrayType(IntegerType, containsNull = false))
419+
assert(ret.resolved)
420+
checkEvaluation(ret, Seq.empty)
421+
}
422+
415423
{
416424
val ret = cast(array, ArrayType(BooleanType, containsNull = false))
417425
assert(ret.resolved === false)
@@ -1157,6 +1165,17 @@ class CastSuite extends CastSuiteBase {
11571165
StructType(StructField("a", IntegerType, true) :: Nil)))
11581166
}
11591167

1168+
test("SPARK-31227: Non-nullable null type should not coerce to nullable type") {
1169+
assert(Cast.canCast(ArrayType(NullType, false), ArrayType(IntegerType, false)))
1170+
1171+
assert(Cast.canCast(
1172+
MapType(NullType, NullType, false), MapType(IntegerType, IntegerType, false)))
1173+
1174+
assert(Cast.canCast(
1175+
StructType(StructField("a", NullType, false) :: Nil),
1176+
StructType(StructField("a", IntegerType, false) :: Nil)))
1177+
}
1178+
11601179
test("Cast should output null for invalid strings when ANSI is not enabled.") {
11611180
withSQLConf(SQLConf.ANSI_ENABLED.key -> "false") {
11621181
checkEvaluation(cast("abdef", DecimalType.USER_DEFAULT), null)

sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1532,6 +1532,13 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession {
15321532
assert(e.getMessage.contains("string, binary or array"))
15331533
}
15341534

1535+
test("SPARK-31227: Non-nullable null type should not coerce to nullable type in concat") {
1536+
val actual = spark.range(1).selectExpr("concat(array(), array(1)) as arr")
1537+
val expected = spark.range(1).selectExpr("array(1) as arr")
1538+
checkAnswer(actual, expected)
1539+
assert(actual.schema === expected.schema)
1540+
}
1541+
15351542
test("flatten function") {
15361543
// Test cases with a primitive type
15371544
val intDF = Seq(

0 commit comments

Comments
 (0)