Skip to content

Commit c1990b5

Browse files
committed
Division operator support integral division
1 parent 7548a88 commit c1990b5

File tree

3 files changed

+41
-5
lines changed

3 files changed

+41
-5
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ object TypeCoercion {
5959
CaseWhenCoercion ::
6060
IfCoercion ::
6161
StackCoercion ::
62-
Division ::
62+
Division(conf) ::
6363
ImplicitTypeCasts ::
6464
DateTimeOperations ::
6565
WindowFrameCoercion ::
@@ -666,7 +666,7 @@ object TypeCoercion {
666666
* Hive only performs integral division with the DIV operator. The arguments to / are always
667667
* converted to fractional types.
668668
*/
669-
object Division extends TypeCoercionRule {
669+
case class Division(conf: SQLConf) extends TypeCoercionRule {
670670
override protected def coerceTypes(
671671
plan: LogicalPlan): LogicalPlan = plan resolveExpressions {
672672
// Skip nodes who has not been resolved yet,
@@ -677,7 +677,12 @@ object TypeCoercion {
677677
case d: Divide if d.dataType == DoubleType => d
678678
case d: Divide if d.dataType.isInstanceOf[DecimalType] => d
679679
case Divide(left, right) if isNumericOrNull(left) && isNumericOrNull(right) =>
680-
Divide(Cast(left, DoubleType), Cast(right, DoubleType))
680+
(left.dataType, right.dataType) match {
681+
case (_: IntegralType, _: IntegralType) if conf.preferIntegralDivision =>
682+
IntegralDivide(left, right)
683+
case _ =>
684+
Divide(Cast(left, DoubleType), Cast(right, DoubleType))
685+
}
681686
}
682687

683688
private def isNumericOrNull(ex: Expression): Boolean = {

sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1524,6 +1524,12 @@ object SQLConf {
15241524
.booleanConf
15251525
.createWithDefault(false)
15261526

1527+
val PREFER_INTEGRAL_DIVISION = buildConf("spark.sql.function.preferIntegralDivision")
1528+
.doc("When true, will perform integral division with the / operator " +
1529+
"if both sides are integral types.")
1530+
.booleanConf
1531+
.createWithDefault(false)
1532+
15271533
val ALLOW_CREATING_MANAGED_TABLE_USING_NONEMPTY_LOCATION =
15281534
buildConf("spark.sql.legacy.allowCreatingManagedTableUsingNonemptyLocation")
15291535
.internal()
@@ -2294,6 +2300,8 @@ class SQLConf extends Serializable with Logging {
22942300

22952301
def eltOutputAsString: Boolean = getConf(ELT_OUTPUT_AS_STRING)
22962302

2303+
def preferIntegralDivision: Boolean = getConf(PREFER_INTEGRAL_DIVISION)
2304+
22972305
def allowCreatingManagedTableUsingNonemptyLocation: Boolean =
22982306
getConf(ALLOW_CREATING_MANAGED_TABLE_USING_NONEMPTY_LOCATION)
22992307

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1456,7 +1456,7 @@ class TypeCoercionSuite extends AnalysisTest {
14561456

14571457
test("SPARK-15776 Divide expression's dataType should be casted to Double or Decimal " +
14581458
"in aggregation function like sum") {
1459-
val rules = Seq(FunctionArgumentConversion, Division)
1459+
val rules = Seq(FunctionArgumentConversion, Division(conf))
14601460
// Casts Integer to Double
14611461
ruleTest(rules, sum(Divide(4, 3)), sum(Divide(Cast(4, DoubleType), Cast(3, DoubleType))))
14621462
// Left expression is Double, right expression is Int. Another rule ImplicitTypeCasts will
@@ -1475,12 +1475,35 @@ class TypeCoercionSuite extends AnalysisTest {
14751475
}
14761476

14771477
test("SPARK-17117 null type coercion in divide") {
1478-
val rules = Seq(FunctionArgumentConversion, Division, ImplicitTypeCasts)
1478+
val rules = Seq(FunctionArgumentConversion, Division(conf), ImplicitTypeCasts)
14791479
val nullLit = Literal.create(null, NullType)
14801480
ruleTest(rules, Divide(1L, nullLit), Divide(Cast(1L, DoubleType), Cast(nullLit, DoubleType)))
14811481
ruleTest(rules, Divide(nullLit, 1L), Divide(Cast(nullLit, DoubleType), Cast(1L, DoubleType)))
14821482
}
14831483

1484+
test("SPARK-28395 Division operator support integral division") {
1485+
val rules = Seq(FunctionArgumentConversion, Division(conf))
1486+
Seq(true, false).foreach { preferIntegralDivision =>
1487+
withSQLConf(SQLConf.PREFER_INTEGRAL_DIVISION.key -> s"$preferIntegralDivision") {
1488+
val result1 = if (preferIntegralDivision) {
1489+
IntegralDivide(1L, 1L)
1490+
} else {
1491+
Divide(Cast(1L, DoubleType), Cast(1L, DoubleType))
1492+
}
1493+
ruleTest(rules, Divide(1L, 1L), result1)
1494+
val result2 = if (preferIntegralDivision) {
1495+
IntegralDivide(1, Cast(1, ShortType))
1496+
} else {
1497+
Divide(Cast(1, DoubleType), Cast(Cast(1, ShortType), DoubleType))
1498+
}
1499+
ruleTest(rules, Divide(1, Cast(1, ShortType)), result2)
1500+
1501+
ruleTest(rules, Divide(1L, 1D), Divide(Cast(1L, DoubleType), Cast(1D, DoubleType)))
1502+
ruleTest(rules, Divide(Decimal(1.1), 1L), Divide(Decimal(1.1), 1L))
1503+
}
1504+
}
1505+
}
1506+
14841507
test("binary comparison with string promotion") {
14851508
val rule = TypeCoercion.PromoteStrings(conf)
14861509
ruleTest(rule,

0 commit comments

Comments
 (0)