From 8b18deb2a6c58b5cb99c6cab0eee9dd69c854993 Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Fri, 17 Apr 2020 17:22:10 +0900 Subject: [PATCH] Fix --- .../sql/catalyst/analysis/TypeCoercion.scala | 19 +++++++++++++++---- .../catalyst/analysis/TypeCoercionSuite.scala | 18 ++++++++++++++++++ 2 files changed, 33 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index eb9a4d4feb783..c6e3f56766a8a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -842,15 +842,26 @@ object TypeCoercion { * Casts types according to the expected input types for [[Expression]]s. */ object ImplicitTypeCasts extends TypeCoercionRule { + + private def canHandleTypeCoercion(leftType: DataType, rightType: DataType): Boolean = { + (leftType, rightType) match { + case (_: DecimalType, NullType) => true + case (NullType, _: DecimalType) => true + case _ => + // If DecimalType operands are involved except for the two cases above, + // DecimalPrecision will handle it. + !leftType.isInstanceOf[DecimalType] && !rightType.isInstanceOf[DecimalType] && + leftType != rightType + } + } + override protected def coerceTypes( plan: LogicalPlan): LogicalPlan = plan resolveExpressions { // Skip nodes who's children have not been resolved yet. case e if !e.childrenResolved => e - // If DecimalType operands are involved, DecimalPrecision will handle it - case b @ BinaryOperator(left, right) if !left.dataType.isInstanceOf[DecimalType] && - !right.dataType.isInstanceOf[DecimalType] && - left.dataType != right.dataType => + case b @ BinaryOperator(left, right) + if canHandleTypeCoercion(left.dataType, right.dataType) => findTightestCommonType(left.dataType, right.dataType).map { commonType => if (b.inputType.acceptsType(commonType)) { // If the expression accepts the tightest common type, cast to that. diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala index ab21a9ea5ba18..e37555f1c0ec3 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala @@ -1541,6 +1541,24 @@ class TypeCoercionSuite extends AnalysisTest { Multiply(CaseWhen(Seq((EqualTo(1, 2), Cast(1, DecimalType(34, 24)))), Cast(100, DecimalType(34, 24))), Cast(1, IntegerType))) } + + test("SPARK-31468: null types should be casted to decimal types in ImplicitTypeCasts") { + Seq(AnyTypeBinaryOperator(_, _), NumericTypeBinaryOperator(_, _)).foreach { binaryOp => + // binaryOp(decimal, null) case + ruleTest(TypeCoercion.ImplicitTypeCasts, + binaryOp(Literal.create(null, DecimalType.SYSTEM_DEFAULT), + Literal.create(null, NullType)), + binaryOp(Literal.create(null, DecimalType.SYSTEM_DEFAULT), + Cast(Literal.create(null, NullType), DecimalType.SYSTEM_DEFAULT))) + + // binaryOp(null, decimal) case + ruleTest(TypeCoercion.ImplicitTypeCasts, + binaryOp(Literal.create(null, NullType), + Literal.create(null, DecimalType.SYSTEM_DEFAULT)), + binaryOp(Cast(Literal.create(null, NullType), DecimalType.SYSTEM_DEFAULT), + Literal.create(null, DecimalType.SYSTEM_DEFAULT))) + } + } }