From df08eeacd85187ca5a71463fc5d25f63426ebe84 Mon Sep 17 00:00:00 2001 From: Sean Zhong Date: Mon, 13 Jun 2016 15:09:20 -0700 Subject: [PATCH 1/6] SPARK-15776 Divide Expression inside an Aggregation function is cast to the wrong type --- .../spark/sql/catalyst/analysis/TypeCoercion.scala | 2 +- .../spark/sql/catalyst/expressions/arithmetic.scala | 2 +- .../scala/org/apache/spark/sql/SQLQuerySuite.scala | 11 +++++++++++ 3 files changed, 13 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index a5b5b91e4ab3a..cc5dfcd2db901 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -525,7 +525,7 @@ object TypeCoercion { def apply(plan: LogicalPlan): LogicalPlan = plan resolveExpressions { // Skip nodes who has not been resolved yet, // as this is an extra rule which should be applied at last. 
- case e if !e.resolved => e + case e if !e.childrenResolved => e // Decimal and Double remain the same case d: Divide if d.dataType == DoubleType => d diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index b2df79a58884b..137607c6b592a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -213,7 +213,7 @@ case class Multiply(left: Expression, right: Expression) case class Divide(left: Expression, right: Expression) extends BinaryArithmetic with NullIntolerant { - override def inputType: AbstractDataType = NumericType + override def inputType: AbstractDataType = TypeCollection(DoubleType, DecimalType) override def symbol: String = "/" override def decimalMethod: String = "$div" diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 89f868509965e..695d9084780c2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -2847,4 +2847,15 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { test("SPARK-15887: hive-site.xml should be loaded") { assert(spark.sessionState.newHadoopConf().get("hive.in.test") == "true") } + + test("SPARK-15776 Divide expression inside an Aggregation function should not " + + "be casted to wrong type") { + val doubleSchema = StructType(StructField("a", DoubleType, true) :: Nil) + assert(sql("select sum(4/3) as a").schema == doubleSchema) + assert(sql("select sum(cast(4.0 as double) / 3) as a").schema == doubleSchema) + assert(sql("select sum(cast(4.0 as float) / 3) as a").schema == doubleSchema) + + val decimalSchema = StructType(StructField("a", DecimalType(31, 11), 
true) :: Nil) + assert(sql("select sum(cast(4.0 as decimal) / 3) as a").schema == decimalSchema) + } } From bce5ea71f36641233a8122178724dcd2578873d3 Mon Sep 17 00:00:00 2001 From: Sean Zhong Date: Tue, 14 Jun 2016 13:28:58 -0700 Subject: [PATCH 2/6] change the UT, fix comments --- .../sql/catalyst/expressions/arithmetic.scala | 1 - .../catalyst/analysis/TypeCoercionSuite.scala | 31 ++++++++++++++++++- .../org/apache/spark/sql/SQLQuerySuite.scala | 11 ------- 3 files changed, 30 insertions(+), 13 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index 137607c6b592a..4db1352291e0b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -221,7 +221,6 @@ case class Divide(left: Expression, right: Expression) private lazy val div: (Any, Any) => Any = dataType match { case ft: FractionalType => ft.fractional.asInstanceOf[Fractional[Any]].div - case it: IntegralType => it.integral.asInstanceOf[Integral[Any]].quot } override def eval(input: InternalRow): Any = { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala index 7435399b1492a..a002a633f8756 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala @@ -19,10 +19,12 @@ package org.apache.spark.sql.catalyst.analysis import java.sql.Timestamp +import org.apache.spark.sql.catalyst.analysis.TypeCoercion.{Division, FunctionArgumentConversion} +import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions._ import 
org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.catalyst.rules.{Rule, RuleExecutor} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.CalendarInterval @@ -205,6 +207,14 @@ class TypeCoercionSuite extends PlanTest { Project(Seq(Alias(transformed, "a")()), testRelation)) } + private def ruleTest( + rule: RuleExecutor[LogicalPlan], initial: Expression, transformed: Expression): Unit = { + val testRelation = LocalRelation(AttributeReference("a", IntegerType)()) + comparePlans( + rule.execute(Project(Seq(Alias(initial, "a")()), testRelation)), + Project(Seq(Alias(transformed, "a")()), testRelation)) + } + test("cast NullType for expressions that implement ExpectsInputTypes") { import TypeCoercionSuite._ @@ -630,6 +640,25 @@ class TypeCoercionSuite extends PlanTest { Seq(Cast(Literal(1), StringType), Cast(Literal("b"), StringType))) ) } + + test("SPARK-15776 Divide expression's dataType should be casted to Double or Decimal") { + val analyzer = new RuleExecutor[LogicalPlan] { + override val batches = + Seq(Batch("Resolution", FixedPoint(10), FunctionArgumentConversion, Division)) + } + + // Cast integer to double + ruleTest(analyzer, sum(Divide(4, 3)), sum(Divide(Cast(4, DoubleType), Cast(3, DoubleType)))) + // left expression is already Double, skip + ruleTest(analyzer, sum(Divide(4.0, 3)), sum(Divide(4.0, 3))) + // Cast Float to Double + ruleTest( + analyzer, + sum(Divide(4.0f, 3)), + sum(Divide(Cast(4.0f, DoubleType), Cast(3, DoubleType)))) + // left expression is already Decimal, skip + ruleTest(analyzer, sum(Divide(Decimal(4.0), 3)), sum(Divide(Decimal(4.0), 3))) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 695d9084780c2..89f868509965e 100644 --- 
a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -2847,15 +2847,4 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { test("SPARK-15887: hive-site.xml should be loaded") { assert(spark.sessionState.newHadoopConf().get("hive.in.test") == "true") } - - test("SPARK-15776 Divide expression inside an Aggregation function should not " + - "be casted to wrong type") { - val doubleSchema = StructType(StructField("a", DoubleType, true) :: Nil) - assert(sql("select sum(4/3) as a").schema == doubleSchema) - assert(sql("select sum(cast(4.0 as double) / 3) as a").schema == doubleSchema) - assert(sql("select sum(cast(4.0 as float) / 3) as a").schema == doubleSchema) - - val decimalSchema = StructType(StructField("a", DecimalType(31, 11), true) :: Nil) - assert(sql("select sum(cast(4.0 as decimal) / 3) as a").schema == decimalSchema) - } } From 659f17c29137d4b8d1fbdf811b9bd134e00ad0a4 Mon Sep 17 00:00:00 2001 From: Sean Zhong Date: Tue, 14 Jun 2016 15:44:36 -0700 Subject: [PATCH 3/6] fix UT --- .../sql/catalyst/analysis/TypeCoercion.scala | 6 ++-- .../ExpressionTypeCheckingSuite.scala | 2 +- .../catalyst/analysis/TypeCoercionSuite.scala | 35 +++++++++---------- .../ArithmeticExpressionSuite.scala | 30 +++++++++++----- .../plans/ConstraintPropagationSuite.scala | 4 +-- 5 files changed, 45 insertions(+), 32 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index cc5dfcd2db901..16df628a5730c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -530,9 +530,11 @@ object TypeCoercion { // Decimal and Double remain the same case d: Divide if d.dataType == DoubleType => d case d: Divide if 
d.dataType.isInstanceOf[DecimalType] => d - - case Divide(left, right) => Divide(Cast(left, DoubleType), Cast(right, DoubleType)) + case Divide(left, right) if isNumeric(left) && isNumeric(right) => + Divide(Cast(left, DoubleType), Cast(right, DoubleType)) } + + private def isNumeric(ex: Expression): Boolean = ex.dataType.isInstanceOf[NumericType] } /** diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala index 660dc86c3e284..54436ea9a4a72 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala @@ -85,7 +85,7 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite { assertError(Subtract('booleanField, 'booleanField), "requires (numeric or calendarinterval) type") assertError(Multiply('booleanField, 'booleanField), "requires numeric type") - assertError(Divide('booleanField, 'booleanField), "requires numeric type") + assertError(Divide('booleanField, 'booleanField), "requires (double or decimal) type") assertError(Remainder('booleanField, 'booleanField), "requires numeric type") assertError(BitwiseAnd('booleanField, 'booleanField), "requires integral type") diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala index a002a633f8756..43e80a9dddef3 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala @@ -201,17 +201,18 @@ class TypeCoercionSuite extends PlanTest { } private def ruleTest(rule: Rule[LogicalPlan], initial: Expression, transformed: Expression) { 
- val testRelation = LocalRelation(AttributeReference("a", IntegerType)()) - comparePlans( - rule(Project(Seq(Alias(initial, "a")()), testRelation)), - Project(Seq(Alias(transformed, "a")()), testRelation)) + ruleTest(Seq(rule), initial, transformed) } private def ruleTest( - rule: RuleExecutor[LogicalPlan], initial: Expression, transformed: Expression): Unit = { + rules: Seq[Rule[LogicalPlan]], initial: Expression, transformed: Expression): Unit = { val testRelation = LocalRelation(AttributeReference("a", IntegerType)()) + val analyzer = new RuleExecutor[LogicalPlan] { + override val batches = Seq(Batch("Resolution", FixedPoint(3), rules: _*)) + } + comparePlans( - rule.execute(Project(Seq(Alias(initial, "a")()), testRelation)), + analyzer.execute(Project(Seq(Alias(initial, "a")()), testRelation)), Project(Seq(Alias(transformed, "a")()), testRelation)) } @@ -642,22 +643,18 @@ class TypeCoercionSuite extends PlanTest { } test("SPARK-15776 Divide expression's dataType should be casted to Double or Decimal") { - val analyzer = new RuleExecutor[LogicalPlan] { - override val batches = - Seq(Batch("Resolution", FixedPoint(10), FunctionArgumentConversion, Division)) - } - - // Cast integer to double - ruleTest(analyzer, sum(Divide(4, 3)), sum(Divide(Cast(4, DoubleType), Cast(3, DoubleType)))) - // left expression is already Double, skip - ruleTest(analyzer, sum(Divide(4.0, 3)), sum(Divide(4.0, 3))) - // Cast Float to Double + val rules = Seq(FunctionArgumentConversion, Division) + // Casts integer to double + ruleTest(rules, sum(Divide(4, 3)), sum(Divide(Cast(4, DoubleType), Cast(3, DoubleType)))) + // Left expression is already Double, skip + ruleTest(rules, sum(Divide(4.0, 3)), sum(Divide(4.0, 3))) + // Casts Float to Double ruleTest( - analyzer, + rules, sum(Divide(4.0f, 3)), sum(Divide(Cast(4.0f, DoubleType), Cast(3, DoubleType)))) - // left expression is already Decimal, skip - ruleTest(analyzer, sum(Divide(Decimal(4.0), 3)), sum(Divide(Decimal(4.0), 3))) + // Lefts 
expression is already Decimal, skip + ruleTest(rules, sum(Divide(Decimal(4.0), 3)), sum(Divide(Decimal(4.0), 3))) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala index 72285c6a24199..ccd5390b1bd37 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala @@ -117,8 +117,13 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper } } + private def testDecimalAndDoubleType(testFunc: (Int => Any) => Unit): Unit = { + testFunc(_.toDouble) + testFunc(Decimal(_)) + } + test("/ (Divide) basic") { - testNumericDataTypes { convert => + testDecimalAndDoubleType { convert => val left = Literal(convert(2)) val right = Literal(convert(1)) val dataType = left.dataType @@ -133,7 +138,22 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper } } - test("/ (Divide) for integral type") { + test("SPARK-15776: Divide report unresolved when children's data type is " + + "neither DecimalType nor DoubleType") { + assert(!(Divide(Literal(1.toByte), Literal(2.toByte)).resolved)) + assert(!(Divide(Literal(1.toShort), Literal(2.toShort)).resolved)) + assert(!(Divide(Literal(1), Literal(2)).resolved)) + assert(!(Divide(Literal(1L), Literal(2L)).resolved)) + assert(!(Divide(Literal(1.0f), Literal(2.0f)).resolved)) + + // Resolved if children's dataType is DoubleType or DecimalType + assert(Divide(Literal(1.0), Literal(2.0)).resolved) + assert(Divide(Literal(Decimal(1.0)), Literal(Decimal(2.0))).resolved) + } + + // By fixing SPARK-15776, Divide's inputType is required to be DoubleType of DecimalType. + // TODO: in future release, we should add a IntegerDivide to support integral types. 
+ ignore("/ (Divide) for integral type") { checkEvaluation(Divide(Literal(1.toByte), Literal(2.toByte)), 0.toByte) checkEvaluation(Divide(Literal(1.toShort), Literal(2.toShort)), 0.toShort) checkEvaluation(Divide(Literal(1), Literal(2)), 0) @@ -143,12 +163,6 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(Divide(positiveLongLit, negativeLongLit), 0L) } - test("/ (Divide) for floating point") { - checkEvaluation(Divide(Literal(1.0f), Literal(2.0f)), 0.5f) - checkEvaluation(Divide(Literal(1.0), Literal(2.0)), 0.5) - checkEvaluation(Divide(Literal(Decimal(1.0)), Literal(Decimal(2.0))), Decimal(0.5)) - } - test("% (Remainder)") { testNumericDataTypes { convert => val left = Literal(convert(1)) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala index 81cc6b123cdd4..0b73b5e009b79 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala @@ -298,7 +298,7 @@ class ConstraintPropagationSuite extends SparkFunSuite { Cast(resolveColumn(tr, "a"), LongType) * resolveColumn(tr, "b") + Cast(100, LongType) === Cast(resolveColumn(tr, "c"), LongType), Cast(resolveColumn(tr, "d"), DoubleType) / - Cast(Cast(10, LongType), DoubleType) === + Cast(10, DoubleType) === Cast(resolveColumn(tr, "e"), DoubleType), IsNotNull(resolveColumn(tr, "a")), IsNotNull(resolveColumn(tr, "b")), @@ -312,7 +312,7 @@ class ConstraintPropagationSuite extends SparkFunSuite { Cast(resolveColumn(tr, "a"), LongType) * resolveColumn(tr, "b") - Cast(10, LongType) >= Cast(resolveColumn(tr, "c"), LongType), Cast(resolveColumn(tr, "d"), DoubleType) / - Cast(Cast(10, LongType), DoubleType) < + Cast(10, DoubleType) < Cast(resolveColumn(tr, "e"), DoubleType), 
IsNotNull(resolveColumn(tr, "a")), IsNotNull(resolveColumn(tr, "b")), From ff3aa3b891960755dc6134939d6eb8c64e5120e9 Mon Sep 17 00:00:00 2001 From: Sean Zhong Date: Tue, 14 Jun 2016 16:26:21 -0700 Subject: [PATCH 4/6] on wenchen's comment --- .../catalyst/analysis/TypeCoercionSuite.scala | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala index 43e80a9dddef3..1687d44f8010b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala @@ -642,19 +642,27 @@ class TypeCoercionSuite extends PlanTest { ) } - test("SPARK-15776 Divide expression's dataType should be casted to Double or Decimal") { + test("SPARK-15776 Divide expression's dataType should be casted to Double or Decimal " + + "in aggregation function") { val rules = Seq(FunctionArgumentConversion, Division) - // Casts integer to double + // Casts Integer to Double ruleTest(rules, sum(Divide(4, 3)), sum(Divide(Cast(4, DoubleType), Cast(3, DoubleType)))) - // Left expression is already Double, skip + // Left expression is Double, right expression is Int ruleTest(rules, sum(Divide(4.0, 3)), sum(Divide(4.0, 3))) + // Left expression is Int, right expression is Double + ruleTest(rules, sum(Divide(4, 3.0)), sum(Divide(Cast(4, DoubleType), Cast(3.0, DoubleType)))) // Casts Float to Double ruleTest( rules, sum(Divide(4.0f, 3)), sum(Divide(Cast(4.0f, DoubleType), Cast(3, DoubleType)))) - // Lefts expression is already Decimal, skip + // Left expression is Decimal, right expression is Int ruleTest(rules, sum(Divide(Decimal(4.0), 3)), sum(Divide(Decimal(4.0), 3))) + // Left expression is Int, right expression is Decimal + ruleTest( + rules, + sum(Divide(4, Decimal(3.0))), 
+ sum(Divide(Cast(4, DoubleType), Cast(Decimal(3.0), DoubleType)))) } } From 3735411d8dc6210098939721f63aeecbe93cb873 Mon Sep 17 00:00:00 2001 From: Sean Zhong Date: Tue, 14 Jun 2016 21:38:45 -0700 Subject: [PATCH 5/6] fix UT failure --- .../sql/catalyst/expressions/ArithmeticExpressionSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala index ccd5390b1bd37..46e9f9ba4dbc7 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala @@ -133,7 +133,7 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(Divide(left, Literal(convert(0))), null) // divide by zero } - DataTypeTestUtils.numericTypeWithoutDecimal.foreach { tpe => + Seq(DoubleType, DecimalType.SYSTEM_DEFAULT).foreach { tpe => checkConsistencyBetweenInterpretedAndCodegen(Divide, tpe, tpe) } } From 8fc6ba78a7b81d3a14f77c657f62ebf81403c102 Mon Sep 17 00:00:00 2001 From: Sean Zhong Date: Wed, 15 Jun 2016 11:42:06 -0700 Subject: [PATCH 6/6] add more UT --- .../sql/catalyst/analysis/AnalysisSuite.scala | 32 +++++++++++++++++++ .../catalyst/analysis/TypeCoercionSuite.scala | 17 +++++----- .../ArithmeticExpressionSuite.scala | 13 -------- 3 files changed, 40 insertions(+), 22 deletions(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index 77ea29ead92cc..102c78bd72111 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -345,4 +345,36 @@ class AnalysisSuite extends AnalysisTest { assertAnalysisSuccess(query) } + + private def assertExpressionType( + expression: Expression, + expectedDataType: DataType): Unit = { + val afterAnalyze = + Project(Seq(Alias(expression, "a")()), OneRowRelation).analyze.expressions.head + if (!afterAnalyze.dataType.equals(expectedDataType)) { + fail( + s""" + |data type of expression $expression doesn't match expected: + |Actual data type: + |${afterAnalyze.dataType} + | + |Expected data type: + |${expectedDataType} + """.stripMargin) + } + } + + test("SPARK-15776: test whether Divide expression's data type can be deduced correctly by " + + "analyzer") { + assertExpressionType(sum(Divide(1, 2)), DoubleType) + assertExpressionType(sum(Divide(1.0, 2)), DoubleType) + assertExpressionType(sum(Divide(1, 2.0)), DoubleType) + assertExpressionType(sum(Divide(1.0, 2.0)), DoubleType) + assertExpressionType(sum(Divide(1, 2.0f)), DoubleType) + assertExpressionType(sum(Divide(1.0f, 2)), DoubleType) + assertExpressionType(sum(Divide(1, Decimal(2))), DecimalType(31, 11)) + assertExpressionType(sum(Divide(Decimal(1), 2)), DecimalType(31, 11)) + assertExpressionType(sum(Divide(Decimal(1), 2.0)), DoubleType) + assertExpressionType(sum(Divide(1.0, Decimal(2.0))), DoubleType) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala index 1687d44f8010b..971c99b671671 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala @@ -205,7 +205,9 @@ class TypeCoercionSuite extends PlanTest { } private def ruleTest( - rules: Seq[Rule[LogicalPlan]], initial: Expression, transformed: Expression): Unit = { + 
rules: Seq[Rule[LogicalPlan]], + initial: Expression, + transformed: Expression): Unit = { val testRelation = LocalRelation(AttributeReference("a", IntegerType)()) val analyzer = new RuleExecutor[LogicalPlan] { override val batches = Seq(Batch("Resolution", FixedPoint(3), rules: _*)) @@ -643,11 +645,12 @@ class TypeCoercionSuite extends PlanTest { } test("SPARK-15776 Divide expression's dataType should be casted to Double or Decimal " + - "in aggregation function") { + "in aggregation function like sum") { val rules = Seq(FunctionArgumentConversion, Division) // Casts Integer to Double ruleTest(rules, sum(Divide(4, 3)), sum(Divide(Cast(4, DoubleType), Cast(3, DoubleType)))) - // Left expression is Double, right expression is Int + // Left expression is Double, right expression is Int. Another rule ImplicitTypeCasts will + // cast the right expression to Double. ruleTest(rules, sum(Divide(4.0, 3)), sum(Divide(4.0, 3))) // Left expression is Int, right expression is Double ruleTest(rules, sum(Divide(4, 3.0)), sum(Divide(Cast(4, DoubleType), Cast(3.0, DoubleType)))) @@ -656,13 +659,9 @@ class TypeCoercionSuite extends PlanTest { rules, sum(Divide(4.0f, 3)), sum(Divide(Cast(4.0f, DoubleType), Cast(3, DoubleType)))) - // Left expression is Decimal, right expression is Int + // Left expression is Decimal, right expression is Int. Another rule DecimalPrecision will cast + // the right expression to Decimal. 
ruleTest(rules, sum(Divide(Decimal(4.0), 3)), sum(Divide(Decimal(4.0), 3))) - // Left expression is Int, right expression is Decimal - ruleTest( - rules, - sum(Divide(4, Decimal(3.0))), - sum(Divide(Cast(4, DoubleType), Cast(Decimal(3.0), DoubleType)))) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala index 46e9f9ba4dbc7..2e37887fbc822 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala @@ -138,19 +138,6 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper } } - test("SPARK-15776: Divide report unresolved when children's data type is " + - "neither DecimalType nor DoubleType") { - assert(!(Divide(Literal(1.toByte), Literal(2.toByte)).resolved)) - assert(!(Divide(Literal(1.toShort), Literal(2.toShort)).resolved)) - assert(!(Divide(Literal(1), Literal(2)).resolved)) - assert(!(Divide(Literal(1L), Literal(2L)).resolved)) - assert(!(Divide(Literal(1.0f), Literal(2.0f)).resolved)) - - // Resolved if children's dataType is DoubleType or DecimalType - assert(Divide(Literal(1.0), Literal(2.0)).resolved) - assert(Divide(Literal(Decimal(1.0)), Literal(Decimal(2.0))).resolved) - } - // By fixing SPARK-15776, Divide's inputType is required to be DoubleType of DecimalType. // TODO: in future release, we should add a IntegerDivide to support integral types. ignore("/ (Divide) for integral type") {