Skip to content

Commit 8bde03b

Browse files
Davies Liurxin
authored andcommitted
[SPARK-17494][SQL] changePrecision() on compact decimal should respect rounding mode
## What changes were proposed in this pull request? Floor()/Ceil() of decimal is implemented using changePrecision() by passing a rounding mode, but the rounding mode is not respected when the decimal is in compact mode (could fit within a Long). This Update the changePrecision() to respect rounding mode, which could be ROUND_FLOOR, ROUND_CEIL, ROUND_HALF_UP, ROUND_HALF_EVEN. ## How was this patch tested? Added regression tests. Author: Davies Liu <[email protected]> Closes #15154 from davies/decimal_round.
1 parent 3497ebe commit 8bde03b

File tree

2 files changed

+39
-4
lines changed

2 files changed

+39
-4
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -242,10 +242,30 @@ final class Decimal extends Ordered[Decimal] with Serializable {
242242
if (scale < _scale) {
243243
// Easier case: we just need to divide our scale down
244244
val diff = _scale - scale
245-
val droppedDigits = longVal % POW_10(diff)
246-
longVal /= POW_10(diff)
247-
if (math.abs(droppedDigits) * 2 >= POW_10(diff)) {
248-
longVal += (if (longVal < 0) -1L else 1L)
245+
val pow10diff = POW_10(diff)
246+
// % and / always round to 0
247+
val droppedDigits = longVal % pow10diff
248+
longVal /= pow10diff
249+
roundMode match {
250+
case ROUND_FLOOR =>
251+
if (droppedDigits < 0) {
252+
longVal += -1L
253+
}
254+
case ROUND_CEILING =>
255+
if (droppedDigits > 0) {
256+
longVal += 1L
257+
}
258+
case ROUND_HALF_UP =>
259+
if (math.abs(droppedDigits) * 2 >= pow10diff) {
260+
longVal += (if (droppedDigits < 0) -1L else 1L)
261+
}
262+
case ROUND_HALF_EVEN =>
263+
val doubled = math.abs(droppedDigits) * 2
264+
if (doubled > pow10diff || doubled == pow10diff && longVal % 2 != 0) {
265+
longVal += (if (droppedDigits < 0) -1L else 1L)
266+
}
267+
case _ =>
268+
sys.error(s"Not supported rounding mode: $roundMode")
249269
}
250270
} else if (scale > _scale) {
251271
// We might be able to multiply longVal by a power of 10 and not overflow, but if not,

sql/catalyst/src/test/scala/org/apache/spark/sql/types/DecimalSuite.scala

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ package org.apache.spark.sql.types
2020
import org.scalatest.PrivateMethodTester
2121

2222
import org.apache.spark.SparkFunSuite
23+
import org.apache.spark.sql.types.Decimal._
2324

2425
class DecimalSuite extends SparkFunSuite with PrivateMethodTester {
2526
/** Check that a Decimal has the given string representation, precision and scale */
@@ -191,4 +192,18 @@ class DecimalSuite extends SparkFunSuite with PrivateMethodTester {
191192
assert(new Decimal().set(100L, 10, 0).toUnscaledLong === 100L)
192193
assert(Decimal(Long.MaxValue, 100, 0).toUnscaledLong === Long.MaxValue)
193194
}
195+
196+
test("changePrecision() on compact decimal should respect rounding mode") {
197+
Seq(ROUND_FLOOR, ROUND_CEILING, ROUND_HALF_UP, ROUND_HALF_EVEN).foreach { mode =>
198+
Seq("0.4", "0.5", "0.6", "1.0", "1.1", "1.6", "2.5", "5.5").foreach { n =>
199+
Seq("", "-").foreach { sign =>
200+
val bd = BigDecimal(sign + n)
201+
val unscaled = (bd * 10).toLongExact
202+
val d = Decimal(unscaled, 8, 1)
203+
assert(d.changePrecision(10, 0, mode))
204+
assert(d.toString === bd.setScale(0, mode).toString(), s"num: $sign$n, mode: $mode")
205+
}
206+
}
207+
}
208+
}
194209
}

0 commit comments

Comments
 (0)