Skip to content

Commit d7a3535

Browse files
committed
[SPARK-9303] Decimal should use java.math.Decimal directly instead of via Scala wrapper
1 parent d4d762f commit d7a3535

File tree

2 files changed

+27
-25
lines changed

2 files changed

+27
-25
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -192,7 +192,7 @@ case class Cast(child: Expression, dataType: DataType)
192192
}
193193

194194
private[this] def decimalToTimestamp(d: Decimal): Long = {
195-
(d.toBigDecimal * 1000000L).longValue()
195+
d.toJavaBigDecimal.multiply(java.math.BigDecimal.valueOf(1000000L)).longValue()
196196
}
197197
private[this] def doubleToTimestamp(d: Double): Any = {
198198
if (d.isNaN || d.isInfinite) null else (d * 1000000L).toLong

sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala

Lines changed: 26 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717

1818
package org.apache.spark.sql.types
1919

20+
import java.math.{BigDecimal => JavaBigDecimal}
21+
2022
import org.apache.spark.annotation.DeveloperApi
2123

2224
/**
@@ -30,7 +32,7 @@ import org.apache.spark.annotation.DeveloperApi
3032
final class Decimal extends Ordered[Decimal] with Serializable {
3133
import org.apache.spark.sql.types.Decimal.{BIG_DEC_ZERO, MAX_LONG_DIGITS, POW_10, ROUNDING_MODE}
3234

33-
private var decimalVal: BigDecimal = null
35+
private var decimalVal: JavaBigDecimal = null
3436
private var longVal: Long = 0L
3537
private var _precision: Int = 1
3638
private var _scale: Int = 0
@@ -44,7 +46,7 @@ final class Decimal extends Ordered[Decimal] with Serializable {
4446
def set(longVal: Long): Decimal = {
4547
if (longVal <= -POW_10(MAX_LONG_DIGITS) || longVal >= POW_10(MAX_LONG_DIGITS)) {
4648
// We can't represent this compactly as a long without risking overflow
47-
this.decimalVal = BigDecimal(longVal)
49+
this.decimalVal = new JavaBigDecimal(longVal)
4850
this.longVal = 0L
4951
} else {
5052
this.decimalVal = null
@@ -86,7 +88,7 @@ final class Decimal extends Ordered[Decimal] with Serializable {
8688
if (precision < 19) {
8789
return null // Requested precision is too low to represent this value
8890
}
89-
this.decimalVal = BigDecimal(unscaled)
91+
this.decimalVal = new JavaBigDecimal(unscaled)
9092
this.longVal = 0L
9193
} else {
9294
val p = POW_10(math.min(precision, MAX_LONG_DIGITS))
@@ -105,7 +107,7 @@ final class Decimal extends Ordered[Decimal] with Serializable {
105107
* Set this Decimal to the given BigDecimal value, with a given precision and scale.
106108
*/
107109
def set(decimal: BigDecimal, precision: Int, scale: Int): Decimal = {
108-
this.decimalVal = decimal.setScale(scale, ROUNDING_MODE)
110+
this.decimalVal = decimal.setScale(scale, ROUNDING_MODE).underlying()
109111
require(decimalVal.precision <= precision, "Overflowed precision")
110112
this.longVal = 0L
111113
this._precision = precision
@@ -117,7 +119,7 @@ final class Decimal extends Ordered[Decimal] with Serializable {
117119
* Set this Decimal to the given BigDecimal value, inheriting its precision and scale.
118120
*/
119121
def set(decimal: BigDecimal): Decimal = {
120-
this.decimalVal = decimal
122+
this.decimalVal = decimal.underlying()
121123
this.longVal = 0L
122124
this._precision = decimal.precision
123125
this._scale = decimal.scale
@@ -135,19 +137,19 @@ final class Decimal extends Ordered[Decimal] with Serializable {
135137
this
136138
}
137139

138-
def toBigDecimal: BigDecimal = {
140+
def toBigDecimal: BigDecimal = BigDecimal(toJavaBigDecimal)
141+
142+
def toJavaBigDecimal: JavaBigDecimal = {
139143
if (decimalVal.ne(null)) {
140144
decimalVal
141145
} else {
142-
BigDecimal(longVal, _scale)
146+
JavaBigDecimal.valueOf(longVal, _scale)
143147
}
144148
}
145149

146-
def toJavaBigDecimal: java.math.BigDecimal = toBigDecimal.underlying()
147-
148150
def toUnscaledLong: Long = {
149151
if (decimalVal.ne(null)) {
150-
decimalVal.underlying().unscaledValue().longValue()
152+
decimalVal.unscaledValue().longValue()
151153
} else {
152154
longVal
153155
}
@@ -164,9 +166,9 @@ final class Decimal extends Ordered[Decimal] with Serializable {
164166
}
165167
}
166168

167-
def toDouble: Double = toBigDecimal.doubleValue()
169+
def toDouble: Double = toJavaBigDecimal.doubleValue()
168170

169-
def toFloat: Float = toBigDecimal.floatValue()
171+
def toFloat: Float = toJavaBigDecimal.floatValue()
170172

171173
def toLong: Long = {
172174
if (decimalVal.eq(null)) {
@@ -208,7 +210,7 @@ final class Decimal extends Ordered[Decimal] with Serializable {
208210
longVal *= POW_10(diff)
209211
} else {
210212
// Give up on using Longs; switch to BigDecimal, which we'll modify below
211-
decimalVal = BigDecimal(longVal, _scale)
213+
decimalVal = JavaBigDecimal.valueOf(longVal, _scale)
212214
}
213215
}
214216
// In both cases, we will check whether our precision is okay below
@@ -217,7 +219,7 @@ final class Decimal extends Ordered[Decimal] with Serializable {
217219
if (decimalVal.ne(null)) {
218220
// We get here if either we started with a BigDecimal, or we switched to one because we would
219221
// have overflowed our Long; in either case we must rescale decimalVal to the new scale.
220-
val newVal = decimalVal.setScale(scale, ROUNDING_MODE)
222+
val newVal = decimalVal.setScale(scale, ROUNDING_MODE.id)
221223
if (newVal.precision > precision) {
222224
return false
223225
}
@@ -242,7 +244,7 @@ final class Decimal extends Ordered[Decimal] with Serializable {
242244
if (decimalVal.eq(null) && other.decimalVal.eq(null) && _scale == other._scale) {
243245
if (longVal < other.longVal) -1 else if (longVal == other.longVal) 0 else 1
244246
} else {
245-
toBigDecimal.compare(other.toBigDecimal)
247+
toJavaBigDecimal.compareTo(other.toJavaBigDecimal)
246248
}
247249
}
248250

@@ -253,27 +255,27 @@ final class Decimal extends Ordered[Decimal] with Serializable {
253255
false
254256
}
255257

256-
override def hashCode(): Int = toBigDecimal.hashCode()
258+
override def hashCode(): Int = toJavaBigDecimal.hashCode()
257259

258260
def isZero: Boolean = if (decimalVal.ne(null)) decimalVal == BIG_DEC_ZERO else longVal == 0
259261

260-
def + (that: Decimal): Decimal = Decimal(toBigDecimal + that.toBigDecimal)
262+
def + (that: Decimal): Decimal = Decimal(toJavaBigDecimal.add(that.toJavaBigDecimal))
261263

262-
def - (that: Decimal): Decimal = Decimal(toBigDecimal - that.toBigDecimal)
264+
def - (that: Decimal): Decimal = Decimal(toJavaBigDecimal.subtract(that.toJavaBigDecimal))
263265

264-
def * (that: Decimal): Decimal = Decimal(toBigDecimal * that.toBigDecimal)
266+
def * (that: Decimal): Decimal = Decimal(toJavaBigDecimal.multiply(that.toJavaBigDecimal))
265267

266268
def / (that: Decimal): Decimal =
267-
if (that.isZero) null else Decimal(toBigDecimal / that.toBigDecimal)
269+
if (that.isZero) null else Decimal(toJavaBigDecimal.divide(that.toJavaBigDecimal))
268270

269271
def % (that: Decimal): Decimal =
270-
if (that.isZero) null else Decimal(toBigDecimal % that.toBigDecimal)
272+
if (that.isZero) null else Decimal(toJavaBigDecimal.remainder(that.toJavaBigDecimal))
271273

272274
def remainder(that: Decimal): Decimal = this % that
273275

274276
def unary_- : Decimal = {
275277
if (decimalVal.ne(null)) {
276-
Decimal(-decimalVal)
278+
Decimal(decimalVal.negate())
277279
} else {
278280
Decimal(-longVal, precision, scale)
279281
}
@@ -290,7 +292,7 @@ object Decimal {
290292

291293
private val POW_10 = Array.tabulate[Long](MAX_LONG_DIGITS + 1)(i => math.pow(10, i).toLong)
292294

293-
private val BIG_DEC_ZERO = BigDecimal(0)
295+
private val BIG_DEC_ZERO: JavaBigDecimal = JavaBigDecimal.valueOf(0)
294296

295297
def apply(value: Double): Decimal = new Decimal().set(value)
296298

@@ -300,7 +302,7 @@ object Decimal {
300302

301303
def apply(value: BigDecimal): Decimal = new Decimal().set(value)
302304

303-
def apply(value: java.math.BigDecimal): Decimal = new Decimal().set(value)
305+
def apply(value: JavaBigDecimal): Decimal = new Decimal().set(value)
304306

305307
def apply(value: BigDecimal, precision: Int, scale: Int): Decimal =
306308
new Decimal().set(value, precision, scale)

0 commit comments

Comments
 (0)