diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
index 5181dcc786a3d..81c9580d3f28c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
@@ -71,6 +71,7 @@ package object dsl {
     def - (other: Expression): Expression = Subtract(expr, other)
     def * (other: Expression): Expression = Multiply(expr, other)
     def / (other: Expression): Expression = Divide(expr, other)
+    def div (other: Expression): Expression = IntegralDivide(expr, other)
     def % (other: Expression): Expression = Remainder(expr, other)
     def & (other: Expression): Expression = BitwiseAnd(expr, other)
     def | (other: Expression): Expression = BitwiseOr(expr, other)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
index 91ffac0ba2a60..ac58da9e90a95 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
@@ -207,20 +207,12 @@ case class Multiply(left: Expression, right: Expression)
   protected override def nullSafeEval(input1: Any, input2: Any): Any = numeric.times(input1, input2)
 }
 
-@ExpressionDescription(
-  usage = "a _FUNC_ b - Divides a by b.",
-  extended = "> SELECT 3 _FUNC_ 2;\n 1.5")
-case class Divide(left: Expression, right: Expression)
-  extends BinaryArithmetic with NullIntolerant {
-
-  override def inputType: AbstractDataType = TypeCollection(DoubleType, DecimalType)
-
-  override def symbol: String = "/"
-  override def decimalMethod: String = "$div"
+abstract class DivideBase extends BinaryArithmetic with NullIntolerant {
 
   override def nullable: Boolean = true
 
   private lazy val div: (Any, Any) => Any = dataType match {
     case ft: FractionalType => ft.fractional.asInstanceOf[Fractional[Any]].div
+    case i: IntegralType => i.integral.asInstanceOf[Integral[Any]].quot
   }
 
   override def eval(input: InternalRow): Any = {
@@ -250,10 +242,11 @@ case class Divide(left: Expression, right: Expression)
     }
     val javaType = ctx.javaType(dataType)
     val divide = if (dataType.isInstanceOf[DecimalType]) {
-      s"${eval1.value}.$decimalMethod(${eval2.value})"
+      s"${eval1.value}.$$div(${eval2.value})"
     } else {
-      s"($javaType)(${eval1.value} $symbol ${eval2.value})"
+      s"($javaType)(${eval1.value} / ${eval2.value})"
     }
+
     if (!left.nullable && !right.nullable) {
       ev.copy(code = s"""
         ${eval2.code}
@@ -284,6 +277,26 @@ case class Divide(left: Expression, right: Expression)
   }
 }
 
+@ExpressionDescription(
+  usage = "a _FUNC_ b - Divides a by b (fractional division).",
+  extended = "> SELECT 3 _FUNC_ 2;\n 1.5")
+case class Divide(left: Expression, right: Expression) extends DivideBase {
+
+  override def inputType: AbstractDataType = TypeCollection(DoubleType, DecimalType)
+
+  override def symbol: String = "/"
+}
+
+@ExpressionDescription(
+  usage = "a _FUNC_ b - Divides a by b.",
+  extended = "> SELECT 3 _FUNC_ 2;\n 1")
+case class IntegralDivide(left: Expression, right: Expression) extends DivideBase {
+
+  override def inputType: AbstractDataType = IntegralType
+
+  override def symbol: String = "div"
+}
+
 @ExpressionDescription(
   usage = "a _FUNC_ b - Returns the remainder when dividing a by b.")
 case class Remainder(left: Expression, right: Expression)
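Illustrative note (not part of the patch): the `DivideBase` refactoring above dispatches on the expression's data type, using `Fractional#div` for fractional inputs and `Integral#quot` for the new integral branch. The standalone sketch below only shows the difference between those two `scala.math` typeclass methods; the object name `DivisionSemantics` is made up for the example.

```scala
import scala.math.{Fractional, Integral}

// Hypothetical demo object, only to illustrate the two typeclass methods
// DivideBase dispatches on.
object DivisionSemantics {
  // Fractional#div performs true division, the semantics of SQL "/".
  def fractionalDiv[T](a: T, b: T)(implicit f: Fractional[T]): T = f.div(a, b)

  // Integral#quot truncates toward zero, the semantics of SQL "DIV".
  def integralDiv[T](a: T, b: T)(implicit i: Integral[T]): T = i.quot(a, b)

  def main(args: Array[String]): Unit = {
    println(fractionalDiv(3.0, 2.0)) // 1.5
    println(integralDiv(3, 2))       // 1
    println(integralDiv(-3, 2))      // -1, truncation toward zero
  }
}
```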
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
index f2cc8d362478a..3234327349a70 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala
@@ -957,7 +957,7 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging {
       case SqlBaseParser.PERCENT =>
         Remainder(left, right)
       case SqlBaseParser.DIV =>
-        Cast(Divide(left, right), LongType)
+        IntegralDivide(left, right)
       case SqlBaseParser.PLUS =>
         Add(left, right)
       case SqlBaseParser.MINUS =>
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala
index 2e37887fbc822..fceefa4540a28 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala
@@ -138,16 +138,14 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper
     }
   }
 
-  // By fixing SPARK-15776, Divide's inputType is required to be DoubleType of DecimalType.
-  // TODO: in future release, we should add a IntegerDivide to support integral types.
-  ignore("/ (Divide) for integral type") {
-    checkEvaluation(Divide(Literal(1.toByte), Literal(2.toByte)), 0.toByte)
-    checkEvaluation(Divide(Literal(1.toShort), Literal(2.toShort)), 0.toShort)
-    checkEvaluation(Divide(Literal(1), Literal(2)), 0)
-    checkEvaluation(Divide(Literal(1.toLong), Literal(2.toLong)), 0.toLong)
-    checkEvaluation(Divide(positiveShortLit, negativeShortLit), 0.toShort)
-    checkEvaluation(Divide(positiveIntLit, negativeIntLit), 0)
-    checkEvaluation(Divide(positiveLongLit, negativeLongLit), 0L)
+  test("/ (Divide) for integral type") {
+    checkEvaluation(IntegralDivide(Literal(1.toByte), Literal(2.toByte)), 0.toByte)
+    checkEvaluation(IntegralDivide(Literal(1.toShort), Literal(2.toShort)), 0.toShort)
+    checkEvaluation(IntegralDivide(Literal(1), Literal(2)), 0)
+    checkEvaluation(IntegralDivide(Literal(1.toLong), Literal(2.toLong)), 0.toLong)
+    checkEvaluation(IntegralDivide(positiveShortLit, negativeShortLit), 0.toShort)
+    checkEvaluation(IntegralDivide(positiveIntLit, negativeIntLit), 0)
+    checkEvaluation(IntegralDivide(positiveLongLit, negativeLongLit), 0L)
   }
 
   test("% (Remainder)") {
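Illustrative note (not part of the patch): with the parser emitting `IntegralDivide` for the `DIV` keyword, the new expression can also be evaluated directly, mirroring the re-enabled test above. A rough sketch, assuming spark-catalyst on the classpath and that `DivideBase` keeps Divide's existing null-on-division-by-zero behaviour; the object name `IntegralDivideDemo` is made up.

```scala
import org.apache.spark.sql.catalyst.expressions.{IntegralDivide, Literal}

// Hypothetical demo object; evaluates the new expression without a SparkSession.
object IntegralDivideDemo {
  def main(args: Array[String]): Unit = {
    // Integral inputs keep their integral type: Int div Int evaluates to an Int.
    println(IntegralDivide(Literal(7), Literal(2)).eval())  // expected: 3
    // Division by zero is expected to yield null, as Divide already does.
    println(IntegralDivide(Literal(7), Literal(0)).eval())  // expected: null
  }
}
```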
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala
index e73592c7afa28..a207c1c1d5a95 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala
@@ -169,7 +169,7 @@ class ExpressionParserSuite extends PlanTest {
     // Simple operations
     assertEqual("a * b", 'a * 'b)
     assertEqual("a / b", 'a / 'b)
-    assertEqual("a DIV b", ('a / 'b).cast(LongType))
+    assertEqual("a DIV b", ('a div 'b))
     assertEqual("a % b", 'a % 'b)
     assertEqual("a + b", 'a + 'b)
     assertEqual("a - b", 'a - 'b)
@@ -180,7 +180,7 @@ class ExpressionParserSuite extends PlanTest {
     // Check precedences
     assertEqual(
       "a * t | b ^ c & d - e + f % g DIV h / i * k",
-      'a * 't | ('b ^ ('c & ('d - 'e + (('f % 'g / 'h).cast(LongType) / 'i * 'k)))))
+      'a * 't | ('b ^ ('c & ('d - 'e + (('f % 'g div 'h) / 'i * 'k)))))
   }
 
   test("unary arithmetic expressions") {
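Illustrative note (not part of the patch): the `div` operator added to the Catalyst DSL can be used exactly as the parser tests above do. A minimal sketch, assuming the usual DSL import; `DslDivDemo` is a made-up name.

```scala
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions.IntegralDivide

// Hypothetical demo object showing the new DSL operator.
object DslDivDemo {
  def main(args: Array[String]): Unit = {
    // 'a div 'b builds an IntegralDivide over two unresolved attributes,
    // the same tree the parser now produces for "a DIV b".
    val expr = 'a div 'b
    assert(expr.isInstanceOf[IntegralDivide])
    println(expr)
  }
}
```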