diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index 5a5d7c6620fb5..eb9a4d4feb783 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -160,7 +160,7 @@ object TypeCoercion { } case (MapType(kt1, vt1, valueContainsNull1), MapType(kt2, vt2, valueContainsNull2)) => findTypeFunc(kt1, kt2) - .filter { kt => Cast.canCastMapKeyNullSafe(kt1, kt) && Cast.canCastMapKeyNullSafe(kt2, kt) } + .filter { kt => !Cast.forceNullable(kt1, kt) && !Cast.forceNullable(kt2, kt) } .flatMap { kt => findTypeFunc(vt1, vt2).map { vt => MapType(kt, vt, valueContainsNull1 || valueContainsNull2 || diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index 8177136edfd62..8d82956cc6f74 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -77,7 +77,8 @@ object Cast { resolvableNullability(fn || forceNullable(fromType, toType), tn) case (MapType(fromKey, fromValue, fn), MapType(toKey, toValue, tn)) => - canCast(fromKey, toKey) && canCastMapKeyNullSafe(fromKey, toKey) && + canCast(fromKey, toKey) && + (!forceNullable(fromKey, toKey)) && canCast(fromValue, toValue) && resolvableNullability(fn || forceNullable(fromValue, toValue), tn) @@ -97,11 +98,6 @@ object Cast { case _ => false } - def canCastMapKeyNullSafe(fromType: DataType, toType: DataType): Boolean = { - // If the original map key type is NullType, it's OK as the map must be empty. - fromType == NullType || !forceNullable(fromType, toType) - } - /** * Return true if we need to use the `timeZone` information casting `from` type to `to` type. * The patterns matched reflect the current implementation in the Cast node. @@ -210,8 +206,13 @@ object Cast { case _ => false // overflow } + /** + * Returns `true` if casting non-nullable values from `from` type to `to` type + * may return null. Note that the caller side should take care of input nullability + * first and only call this method if the input is not nullable. + */ def forceNullable(from: DataType, to: DataType): Boolean = (from, to) match { - case (NullType, _) => true + case (NullType, _) => false // empty array or map case case (_, _) if from == to => false case (StringType, BinaryType) => false @@ -269,7 +270,7 @@ abstract class CastBase extends UnaryExpression with TimeZoneAwareExpression wit } } - override def nullable: Boolean = Cast.forceNullable(child.dataType, dataType) || child.nullable + override def nullable: Boolean = child.nullable || Cast.forceNullable(child.dataType, dataType) protected def ansiEnabled: Boolean diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala index 95005fd3f5a5e..ab21a9ea5ba18 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala @@ -30,6 +30,7 @@ import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.CalendarInterval class TypeCoercionSuite extends AnalysisTest { + import TypeCoercionSuite._ // scalastyle:off line.size.limit // The following table shows all implicit data type conversions that are not visible to the user. @@ -99,22 +100,6 @@ class TypeCoercionSuite extends AnalysisTest { case _ => Literal.create(null, dataType) } - val integralTypes: Seq[DataType] = - Seq(ByteType, ShortType, IntegerType, LongType) - val fractionalTypes: Seq[DataType] = - Seq(DoubleType, FloatType, DecimalType.SYSTEM_DEFAULT, DecimalType(10, 2)) - val numericTypes: Seq[DataType] = integralTypes ++ fractionalTypes - val atomicTypes: Seq[DataType] = - numericTypes ++ Seq(BinaryType, BooleanType, StringType, DateType, TimestampType) - val complexTypes: Seq[DataType] = - Seq(ArrayType(IntegerType), - ArrayType(StringType), - MapType(StringType, StringType), - new StructType().add("a1", StringType), - new StructType().add("a1", StringType).add("a2", IntegerType)) - val allTypes: Seq[DataType] = - atomicTypes ++ complexTypes ++ Seq(NullType, CalendarIntervalType) - // Check whether the type `checkedType` can be cast to all the types in `castableTypes`, // but cannot be cast to the other types in `allTypes`. private def checkTypeCasting(checkedType: DataType, castableTypes: Seq[DataType]): Unit = { @@ -497,6 +482,23 @@ class TypeCoercionSuite extends AnalysisTest { .add("null", IntegerType, nullable = false), Some(new StructType() .add("null", IntegerType, nullable = true))) + + widenTest( + ArrayType(NullType, containsNull = false), + ArrayType(IntegerType, containsNull = false), + Some(ArrayType(IntegerType, containsNull = false))) + + widenTest(MapType(NullType, NullType, false), + MapType(IntegerType, StringType, false), + Some(MapType(IntegerType, StringType, false))) + + widenTest( + new StructType() + .add("null", NullType, nullable = false), + new StructType() + .add("null", IntegerType, nullable = false), + Some(new StructType() + .add("null", IntegerType, nullable = false))) } test("wider common type for decimal and array") { @@ -728,8 +730,6 @@ class TypeCoercionSuite extends AnalysisTest { } test("cast NullType for expressions that implement ExpectsInputTypes") { - import TypeCoercionSuite._ - ruleTest(TypeCoercion.ImplicitTypeCasts, AnyTypeUnaryExpression(Literal.create(null, NullType)), AnyTypeUnaryExpression(Literal.create(null, NullType))) @@ -740,8 +740,6 @@ class TypeCoercionSuite extends AnalysisTest { } test("cast NullType for binary operators") { - import TypeCoercionSuite._ - ruleTest(TypeCoercion.ImplicitTypeCasts, AnyTypeBinaryOperator(Literal.create(null, NullType), Literal.create(null, NullType)), AnyTypeBinaryOperator(Literal.create(null, NullType), Literal.create(null, NullType))) @@ -1548,6 +1546,22 @@ class TypeCoercionSuite extends AnalysisTest { object TypeCoercionSuite { + val integralTypes: Seq[DataType] = + Seq(ByteType, ShortType, IntegerType, LongType) + val fractionalTypes: Seq[DataType] = + Seq(DoubleType, FloatType, DecimalType.SYSTEM_DEFAULT, DecimalType(10, 2)) + val numericTypes: Seq[DataType] = integralTypes ++ fractionalTypes + val atomicTypes: Seq[DataType] = + numericTypes ++ Seq(BinaryType, BooleanType, StringType, DateType, TimestampType) + val complexTypes: Seq[DataType] = + Seq(ArrayType(IntegerType), + ArrayType(StringType), + MapType(StringType, StringType), + new StructType().add("a1", StringType), + new StructType().add("a1", StringType).add("a2", IntegerType)) + val allTypes: Seq[DataType] = + atomicTypes ++ complexTypes ++ Seq(NullType, CalendarIntervalType) + case class AnyTypeUnaryExpression(child: Expression) extends UnaryExpression with ExpectsInputTypes with Unevaluable { override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala index 8296562ac739a..7083058a906e3 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala @@ -26,6 +26,7 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCoercion.numericPrecedence +import org.apache.spark.sql.catalyst.analysis.TypeCoercionSuite import org.apache.spark.sql.catalyst.expressions.aggregate.{CollectList, CollectSet} import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext import org.apache.spark.sql.catalyst.util.DateTimeConstants._ @@ -412,6 +413,14 @@ abstract class CastSuiteBase extends SparkFunSuite with ExpressionEvalHelper { assert(ret.resolved) checkEvaluation(ret, Seq(null, true, false, null)) } + + { + val array = Literal.create(Seq.empty, ArrayType(NullType, containsNull = false)) + val ret = cast(array, ArrayType(IntegerType, containsNull = false)) + assert(ret.resolved) + checkEvaluation(ret, Seq.empty) + } + { val ret = cast(array, ArrayType(BooleanType, containsNull = false)) assert(ret.resolved === false) @@ -1157,6 +1166,19 @@ class CastSuite extends CastSuiteBase { StructType(StructField("a", IntegerType, true) :: Nil))) } + test("SPARK-31227: Non-nullable null type should not coerce to nullable type") { + TypeCoercionSuite.allTypes.foreach { t => + assert(Cast.canCast(ArrayType(NullType, false), ArrayType(t, false))) + + assert(Cast.canCast( + MapType(NullType, NullType, false), MapType(t, t, false))) + + assert(Cast.canCast( + StructType(StructField("a", NullType, false) :: Nil), + StructType(StructField("a", t, false) :: Nil))) + } + } + test("Cast should output null for invalid strings when ANSI is not enabled.") { withSQLConf(SQLConf.ANSI_ENABLED.key -> "false") { checkEvaluation(cast("abdef", DecimalType.USER_DEFAULT), null) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index c41eb98c13ea0..5d368ef43861b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -1532,6 +1532,13 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { assert(e.getMessage.contains("string, binary or array")) } + test("SPARK-31227: Non-nullable null type should not coerce to nullable type in concat") { + val actual = spark.range(1).selectExpr("concat(array(), array(1)) as arr") + val expected = spark.range(1).selectExpr("array(1) as arr") + checkAnswer(actual, expected) + assert(actual.schema === expected.schema) + } + test("flatten function") { // Test cases with a primitive type val intDF = Seq(