diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index 1fdec89e258a7..3125f8cb732db 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -59,7 +59,7 @@ object TypeCoercion { CaseWhenCoercion :: IfCoercion :: StackCoercion :: - Division :: + Division(conf) :: ImplicitTypeCasts :: DateTimeOperations :: WindowFrameCoercion :: @@ -666,7 +666,7 @@ object TypeCoercion { * Hive only performs integral division with the DIV operator. The arguments to / are always * converted to fractional types. */ - object Division extends TypeCoercionRule { + case class Division(conf: SQLConf) extends TypeCoercionRule { override protected def coerceTypes( plan: LogicalPlan): LogicalPlan = plan resolveExpressions { // Skip nodes who has not been resolved yet, @@ -677,7 +677,12 @@ object TypeCoercion { case d: Divide if d.dataType == DoubleType => d case d: Divide if d.dataType.isInstanceOf[DecimalType] => d case Divide(left, right) if isNumericOrNull(left) && isNumericOrNull(right) => - Divide(Cast(left, DoubleType), Cast(right, DoubleType)) + (left.dataType, right.dataType) match { + case (_: IntegralType, _: IntegralType) if conf.preferIntegralDivision => + IntegralDivide(left, right) + case _ => + Divide(Cast(left, DoubleType), Cast(right, DoubleType)) + } } private def isNumericOrNull(ex: Expression): Boolean = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index f76103e141d0f..57f5128fd4fbe 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -1524,6 +1524,12 @@ object SQLConf { .booleanConf .createWithDefault(false) + val PREFER_INTEGRAL_DIVISION = buildConf("spark.sql.function.preferIntegralDivision") + .doc("When true, will perform integral division with the / operator " + + "if both sides are integral types.") + .booleanConf + .createWithDefault(false) + val ALLOW_CREATING_MANAGED_TABLE_USING_NONEMPTY_LOCATION = buildConf("spark.sql.legacy.allowCreatingManagedTableUsingNonemptyLocation") .internal() @@ -2294,6 +2300,8 @@ class SQLConf extends Serializable with Logging { def eltOutputAsString: Boolean = getConf(ELT_OUTPUT_AS_STRING) + def preferIntegralDivision: Boolean = getConf(PREFER_INTEGRAL_DIVISION) + def allowCreatingManagedTableUsingNonemptyLocation: Boolean = getConf(ALLOW_CREATING_MANAGED_TABLE_USING_NONEMPTY_LOCATION) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala index a725e4b608a5f..949bb30d15503 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala @@ -1456,7 +1456,7 @@ class TypeCoercionSuite extends AnalysisTest { test("SPARK-15776 Divide expression's dataType should be casted to Double or Decimal " + "in aggregation function like sum") { - val rules = Seq(FunctionArgumentConversion, Division) + val rules = Seq(FunctionArgumentConversion, Division(conf)) // Casts Integer to Double ruleTest(rules, sum(Divide(4, 3)), sum(Divide(Cast(4, DoubleType), Cast(3, DoubleType)))) // Left expression is Double, right expression is Int. Another rule ImplicitTypeCasts will @@ -1475,12 +1475,35 @@ class TypeCoercionSuite extends AnalysisTest { } test("SPARK-17117 null type coercion in divide") { - val rules = Seq(FunctionArgumentConversion, Division, ImplicitTypeCasts) + val rules = Seq(FunctionArgumentConversion, Division(conf), ImplicitTypeCasts) val nullLit = Literal.create(null, NullType) ruleTest(rules, Divide(1L, nullLit), Divide(Cast(1L, DoubleType), Cast(nullLit, DoubleType))) ruleTest(rules, Divide(nullLit, 1L), Divide(Cast(nullLit, DoubleType), Cast(1L, DoubleType))) } + test("SPARK-28395 Division operator support integral division") { + val rules = Seq(FunctionArgumentConversion, Division(conf)) + Seq(true, false).foreach { preferIntegralDivision => + withSQLConf(SQLConf.PREFER_INTEGRAL_DIVISION.key -> s"$preferIntegralDivision") { + val result1 = if (preferIntegralDivision) { + IntegralDivide(1L, 1L) + } else { + Divide(Cast(1L, DoubleType), Cast(1L, DoubleType)) + } + ruleTest(rules, Divide(1L, 1L), result1) + val result2 = if (preferIntegralDivision) { + IntegralDivide(1, Cast(1, ShortType)) + } else { + Divide(Cast(1, DoubleType), Cast(Cast(1, ShortType), DoubleType)) + } + ruleTest(rules, Divide(1, Cast(1, ShortType)), result2) + + ruleTest(rules, Divide(1L, 1D), Divide(Cast(1L, DoubleType), Cast(1D, DoubleType))) + ruleTest(rules, Divide(Decimal(1.1), 1L), Divide(Decimal(1.1), 1L)) + } + } + } + test("binary comparison with string promotion") { val rule = TypeCoercion.PromoteStrings(conf) ruleTest(rule,