diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 77860e1584f42..8b69a47036962 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -267,6 +267,7 @@ object FunctionRegistry { expression[Subtract]("-"), expression[Multiply]("*"), expression[Divide]("/"), + expression[IntegralDivide]("div"), expression[Remainder]("%"), // aggregate functions diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala index d3ccd18d0245e..176ea823b1fcd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala @@ -72,6 +72,7 @@ package object dsl { def - (other: Expression): Expression = Subtract(expr, other) def * (other: Expression): Expression = Multiply(expr, other) def / (other: Expression): Expression = Divide(expr, other) + def div (other: Expression): Expression = IntegralDivide(expr, other) def % (other: Expression): Expression = Remainder(expr, other) def & (other: Expression): Expression = BitwiseAnd(expr, other) def | (other: Expression): Expression = BitwiseOr(expr, other) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index c827226d58420..1b1808f8366d2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -314,6 +314,34 @@ case class Divide(left: Expression, right: Expression) extends DivModLike { override def evalOperation(left: Any, right: Any): Any = div(left, right) } +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = "expr1 _FUNC_ expr2 - Divide `expr1` by `expr2` rounded to the long integer. It returns NULL if an operand is NULL or `expr2` is 0.", + examples = """ + Examples: + > SELECT 3 _FUNC_ 2; + 1 + """, + since = "2.5.0") +// scalastyle:on line.size.limit +case class IntegralDivide(left: Expression, right: Expression) extends DivModLike { + + override def inputType: AbstractDataType = IntegralType + override def dataType: DataType = LongType + + override def symbol: String = "/" + override def sqlOperator: String = "div" + + private lazy val div: (Any, Any) => Long = left.dataType match { + case i: IntegralType => + val divide = i.integral.asInstanceOf[Integral[Any]].quot _ + val toLong = i.integral.asInstanceOf[Integral[Any]].toLong _ + (x, y) => toLong(divide(x, y)) + } + + override def evalOperation(left: Any, right: Any): Any = div(left, right) +} + @ExpressionDescription( usage = "expr1 _FUNC_ expr2 - Returns the remainder after `expr1`/`expr2`.", examples = """ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 7bc1f63e30540..5cfb5dc871041 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -1157,7 +1157,7 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging case SqlBaseParser.PERCENT => Remainder(left, right) case SqlBaseParser.DIV => - Cast(Divide(left, right), LongType) + IntegralDivide(left, right) case SqlBaseParser.PLUS => Add(left, right) case SqlBaseParser.MINUS => diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala index 9a752af523ffc..c3c4d9ee6b702 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ArithmeticExpressionSuite.scala @@ -143,16 +143,14 @@ class ArithmeticExpressionSuite extends SparkFunSuite with ExpressionEvalHelper } } - // By fixing SPARK-15776, Divide's inputType is required to be DoubleType of DecimalType. - // TODO: in future release, we should add a IntegerDivide to support integral types. - ignore("/ (Divide) for integral type") { - checkEvaluation(Divide(Literal(1.toByte), Literal(2.toByte)), 0.toByte) - checkEvaluation(Divide(Literal(1.toShort), Literal(2.toShort)), 0.toShort) - checkEvaluation(Divide(Literal(1), Literal(2)), 0) - checkEvaluation(Divide(Literal(1.toLong), Literal(2.toLong)), 0.toLong) - checkEvaluation(Divide(positiveShortLit, negativeShortLit), 0.toShort) - checkEvaluation(Divide(positiveIntLit, negativeIntLit), 0) - checkEvaluation(Divide(positiveLongLit, negativeLongLit), 0L) + test("/ (Divide) for integral type") { + checkEvaluation(IntegralDivide(Literal(1.toByte), Literal(2.toByte)), 0L) + checkEvaluation(IntegralDivide(Literal(1.toShort), Literal(2.toShort)), 0L) + checkEvaluation(IntegralDivide(Literal(1), Literal(2)), 0L) + checkEvaluation(IntegralDivide(Literal(1.toLong), Literal(2.toLong)), 0L) + checkEvaluation(IntegralDivide(positiveShortLit, negativeShortLit), 0L) + checkEvaluation(IntegralDivide(positiveIntLit, negativeIntLit), 0L) + checkEvaluation(IntegralDivide(positiveLongLit, negativeLongLit), 0L) } test("% (Remainder)") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala index 781fc1e957ae0..b4df22c5b29fa 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala @@ -203,7 +203,7 @@ class ExpressionParserSuite extends PlanTest { // Simple operations assertEqual("a * b", 'a * 'b) assertEqual("a / b", 'a / 'b) - assertEqual("a DIV b", ('a / 'b).cast(LongType)) + assertEqual("a DIV b", 'a div 'b) assertEqual("a % b", 'a % 'b) assertEqual("a + b", 'a + 'b) assertEqual("a - b", 'a - 'b) @@ -214,7 +214,7 @@ class ExpressionParserSuite extends PlanTest { // Check precedences assertEqual( "a * t | b ^ c & d - e + f % g DIV h / i * k", - 'a * 't | ('b ^ ('c & ('d - 'e + (('f % 'g / 'h).cast(LongType) / 'i * 'k))))) + 'a * 't | ('b ^ ('c & ('d - 'e + (('f % 'g div 'h) / 'i * 'k))))) } test("unary arithmetic expressions") { diff --git a/sql/core/src/test/resources/sql-tests/results/operators.sql.out b/sql/core/src/test/resources/sql-tests/results/operators.sql.out index 840655b7a6447..2555734756fc4 100644 --- a/sql/core/src/test/resources/sql-tests/results/operators.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/operators.sql.out @@ -157,7 +157,7 @@ NULL -- !query 19 select 5 div 2 -- !query 19 schema -struct +struct<(5 div 2):bigint> -- !query 19 output 2 @@ -165,7 +165,7 @@ struct -- !query 20 select 5 div 0 -- !query 20 schema -struct +struct<(5 div 0):bigint> -- !query 20 output NULL @@ -173,7 +173,7 @@ NULL -- !query 21 select 5 div null -- !query 21 schema -struct +struct<(5 div CAST(NULL AS INT)):bigint> -- !query 21 output NULL @@ -181,7 +181,7 @@ NULL -- !query 22 select null div 5 -- !query 22 schema -struct +struct<(CAST(NULL AS INT) div 5):bigint> -- !query 22 output NULL