Skip to content

Commit 4205a0a

Browse files
committed
Fix inaccuracy precision/scale of Decimal division operation.
1 parent 2848f4d commit 4205a0a

File tree

2 files changed

+20
-4
lines changed

2 files changed

+20
-4
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,14 @@ final class Decimal extends Ordered[Decimal] with Serializable {
145145
}
146146
}
147147

148+
def toLimitedBigDecimal: BigDecimal = {
149+
if (decimalVal.ne(null)) {
150+
decimalVal
151+
} else {
152+
BigDecimal(longVal, _scale)
153+
}
154+
}
155+
148156
def toJavaBigDecimal: java.math.BigDecimal = toBigDecimal.underlying()
149157

150158
def toUnscaledLong: Long = {
@@ -269,9 +277,9 @@ final class Decimal extends Ordered[Decimal] with Serializable {
269277
if (that.isZero) {
270278
null
271279
} else {
272-
// To avoid non-terminating decimal expansion problem, we turn to Java BigDecimal's divide
273-
// with specified ROUNDING_MODE.
274-
Decimal(toJavaBigDecimal.divide(that.toJavaBigDecimal, ROUNDING_MODE.id))
280+
// To avoid non-terminating decimal expansion problem, we get scala's BigDecimal with limited
281+
// precision and scala.
282+
Decimal(toLimitedBigDecimal / that.toLimitedBigDecimal)
275283
}
276284
}
277285

sql/catalyst/src/test/scala/org/apache/spark/sql/types/decimal/DecimalSuite.scala

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,14 @@ class DecimalSuite extends SparkFunSuite with PrivateMethodTester {
170170

171171
test("fix non-terminating decimal expansion problem") {
172172
val decimal = Decimal(1.0, 10, 3) / Decimal(3.0, 10, 3)
173-
assert(decimal.toString === "0.333")
173+
// The difference between decimal should not be more than 0.001.
174+
assert(decimal.toDouble - 0.333 < 0.001)
175+
}
176+
177+
test("fix loss of precision/scale when doing division operation") {
178+
val a = Decimal(2) / Decimal(3)
179+
assert(a.toDouble < 1.0 && a.toDouble > 0.6)
180+
val b = Decimal(1) / Decimal(8)
181+
assert(b.toDouble === 0.125)
174182
}
175183
}

0 commit comments

Comments
 (0)