Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -525,14 +525,16 @@ object TypeCoercion {
def apply(plan: LogicalPlan): LogicalPlan = plan resolveExpressions {
// Skip nodes who has not been resolved yet,
// as this is an extra rule which should be applied at last.
case e if !e.resolved => e
case e if !e.childrenResolved => e

// Decimal and Double remain the same
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can simplify this:

case e if !e.childrenResolved => e
case d: Divide if d.dataType.isInstanceOf[IntegralType] => ...

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seems case d: Divide if d.dataType.isInstanceOf[IntegralType] is not equivalent with the code that we have at here?

(also , let's avoid of unnecessary changes)

case d: Divide if d.dataType == DoubleType => d
case d: Divide if d.dataType.isInstanceOf[DecimalType] => d

case Divide(left, right) => Divide(Cast(left, DoubleType), Cast(right, DoubleType))
case Divide(left, right) if isNumeric(left) && isNumeric(right) =>
Divide(Cast(left, DoubleType), Cast(right, DoubleType))
}

private def isNumeric(ex: Expression): Boolean = ex.dataType.isInstanceOf[NumericType]
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -213,15 +213,14 @@ case class Multiply(left: Expression, right: Expression)
case class Divide(left: Expression, right: Expression)
extends BinaryArithmetic with NullIntolerant {

override def inputType: AbstractDataType = NumericType
override def inputType: AbstractDataType = TypeCollection(DoubleType, DecimalType)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we should also cleanup the divide expression to remove code for integral division.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't we add a dedicated expression for integral division? We currently support DIV in our parser: https://github.com/apache/spark/blob/master/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala#L954

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

+1

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

+1

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@hvanhovell Thanks for the catch. I will try to test hive behavior more, like how select 5 div 3.0 behaves in Hive.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks like div is from mysql (see http://dev.mysql.com/doc/refman/5.7/en/arithmetic-functions.html).

DIV Integer division
/   Division operator

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shall we use FractionalType, which also includes FloatType here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Float is casted to Double.


override def symbol: String = "/"
override def decimalMethod: String = "$div"
override def nullable: Boolean = true

private lazy val div: (Any, Any) => Any = dataType match {
case ft: FractionalType => ft.fractional.asInstanceOf[Fractional[Any]].div
case it: IntegralType => it.integral.asInstanceOf[Integral[Any]].quot
}

override def eval(input: InternalRow): Any = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -345,4 +345,36 @@ class AnalysisSuite extends AnalysisTest {

assertAnalysisSuccess(query)
}

private def assertExpressionType(
expression: Expression,
expectedDataType: DataType): Unit = {
val afterAnalyze =
Project(Seq(Alias(expression, "a")()), OneRowRelation).analyze.expressions.head
if (!afterAnalyze.dataType.equals(expectedDataType)) {
fail(
s"""
|data type of expression $expression doesn't match expected:
|Actual data type:
|${afterAnalyze.dataType}
|
|Expected data type:
|${expectedDataType}
""".stripMargin)
}
}

test("SPARK-15776: test whether Divide expression's data type can be deduced correctly by " +
"analyzer") {
assertExpressionType(sum(Divide(1, 2)), DoubleType)
assertExpressionType(sum(Divide(1.0, 2)), DoubleType)
assertExpressionType(sum(Divide(1, 2.0)), DoubleType)
assertExpressionType(sum(Divide(1.0, 2.0)), DoubleType)
assertExpressionType(sum(Divide(1, 2.0f)), DoubleType)
assertExpressionType(sum(Divide(1.0f, 2)), DoubleType)
assertExpressionType(sum(Divide(1, Decimal(2))), DecimalType(31, 11))
assertExpressionType(sum(Divide(Decimal(1), 2)), DecimalType(31, 11))
assertExpressionType(sum(Divide(Decimal(1), 2.0)), DoubleType)
assertExpressionType(sum(Divide(1.0, Decimal(2.0))), DoubleType)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite {
assertError(Subtract('booleanField, 'booleanField),
"requires (numeric or calendarinterval) type")
assertError(Multiply('booleanField, 'booleanField), "requires numeric type")
assertError(Divide('booleanField, 'booleanField), "requires numeric type")
assertError(Divide('booleanField, 'booleanField), "requires (double or decimal) type")
assertError(Remainder('booleanField, 'booleanField), "requires numeric type")

assertError(BitwiseAnd('booleanField, 'booleanField), "requires integral type")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,12 @@ package org.apache.spark.sql.catalyst.analysis

import java.sql.Timestamp

import org.apache.spark.sql.catalyst.analysis.TypeCoercion.{Division, FunctionArgumentConversion}
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.rules.{Rule, RuleExecutor}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.CalendarInterval

Expand Down Expand Up @@ -199,9 +201,20 @@ class TypeCoercionSuite extends PlanTest {
}

private def ruleTest(rule: Rule[LogicalPlan], initial: Expression, transformed: Expression) {
ruleTest(Seq(rule), initial, transformed)
}

private def ruleTest(
rules: Seq[Rule[LogicalPlan]],
initial: Expression,
transformed: Expression): Unit = {
val testRelation = LocalRelation(AttributeReference("a", IntegerType)())
val analyzer = new RuleExecutor[LogicalPlan] {
override val batches = Seq(Batch("Resolution", FixedPoint(3), rules: _*))
}

comparePlans(
rule(Project(Seq(Alias(initial, "a")()), testRelation)),
analyzer.execute(Project(Seq(Alias(initial, "a")()), testRelation)),
Project(Seq(Alias(transformed, "a")()), testRelation))
}

Expand Down Expand Up @@ -630,6 +643,26 @@ class TypeCoercionSuite extends PlanTest {
Seq(Cast(Literal(1), StringType), Cast(Literal("b"), StringType)))
)
}

test("SPARK-15776 Divide expression's dataType should be casted to Double or Decimal " +
"in aggregation function like sum") {
val rules = Seq(FunctionArgumentConversion, Division)
// Casts Integer to Double
ruleTest(rules, sum(Divide(4, 3)), sum(Divide(Cast(4, DoubleType), Cast(3, DoubleType))))
// Left expression is Double, right expression is Int. Another rule ImplicitTypeCasts will
// cast the right expression to Double.
ruleTest(rules, sum(Divide(4.0, 3)), sum(Divide(4.0, 3)))
// Left expression is Int, right expression is Double
ruleTest(rules, sum(Divide(4, 3.0)), sum(Divide(Cast(4, DoubleType), Cast(3.0, DoubleType))))
// Casts Float to Double
ruleTest(
rules,
sum(Divide(4.0f, 3)),
sum(Divide(Cast(4.0f, DoubleType), Cast(3, DoubleType))))
// Left expression is Decimal, right expression is Int. Another rule DecimalPrecision will cast
// the right expression to Decimal.
ruleTest(rules, sum(Divide(Decimal(4.0), 3)), sum(Divide(Decimal(4.0), 3)))
}
}


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -117,8 +117,13 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper
}
}

private def testDecimalAndDoubleType(testFunc: (Int => Any) => Unit): Unit = {
testFunc(_.toDouble)
testFunc(Decimal(_))
}

test("/ (Divide) basic") {
testNumericDataTypes { convert =>
testDecimalAndDoubleType { convert =>
val left = Literal(convert(2))
val right = Literal(convert(1))
val dataType = left.dataType
Expand All @@ -128,12 +133,14 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper
checkEvaluation(Divide(left, Literal(convert(0))), null) // divide by zero
}

DataTypeTestUtils.numericTypeWithoutDecimal.foreach { tpe =>
Seq(DoubleType, DecimalType.SYSTEM_DEFAULT).foreach { tpe =>
checkConsistencyBetweenInterpretedAndCodegen(Divide, tpe, tpe)
}
}

test("/ (Divide) for integral type") {
// By fixing SPARK-15776, Divide's inputType is required to be DoubleType of DecimalType.
// TODO: in future release, we should add a IntegerDivide to support integral types.
ignore("/ (Divide) for integral type") {
checkEvaluation(Divide(Literal(1.toByte), Literal(2.toByte)), 0.toByte)
checkEvaluation(Divide(Literal(1.toShort), Literal(2.toShort)), 0.toShort)
checkEvaluation(Divide(Literal(1), Literal(2)), 0)
Expand All @@ -143,12 +150,6 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper
checkEvaluation(Divide(positiveLongLit, negativeLongLit), 0L)
}

test("/ (Divide) for floating point") {
checkEvaluation(Divide(Literal(1.0f), Literal(2.0f)), 0.5f)
checkEvaluation(Divide(Literal(1.0), Literal(2.0)), 0.5)
checkEvaluation(Divide(Literal(Decimal(1.0)), Literal(Decimal(2.0))), Decimal(0.5))
}

test("% (Remainder)") {
testNumericDataTypes { convert =>
val left = Literal(convert(1))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -298,7 +298,7 @@ class ConstraintPropagationSuite extends SparkFunSuite {
Cast(resolveColumn(tr, "a"), LongType) * resolveColumn(tr, "b") + Cast(100, LongType) ===
Cast(resolveColumn(tr, "c"), LongType),
Cast(resolveColumn(tr, "d"), DoubleType) /
Cast(Cast(10, LongType), DoubleType) ===
Cast(10, DoubleType) ===
Cast(resolveColumn(tr, "e"), DoubleType),
IsNotNull(resolveColumn(tr, "a")),
IsNotNull(resolveColumn(tr, "b")),
Expand All @@ -312,7 +312,7 @@ class ConstraintPropagationSuite extends SparkFunSuite {
Cast(resolveColumn(tr, "a"), LongType) * resolveColumn(tr, "b") - Cast(10, LongType) >=
Cast(resolveColumn(tr, "c"), LongType),
Cast(resolveColumn(tr, "d"), DoubleType) /
Cast(Cast(10, LongType), DoubleType) <
Cast(10, DoubleType) <
Cast(resolveColumn(tr, "e"), DoubleType),
IsNotNull(resolveColumn(tr, "a")),
IsNotNull(resolveColumn(tr, "b")),
Expand Down