From d7a35358e2068eca9bdead2b93f3b96dcaf890d8 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Thu, 23 Jul 2015 23:02:13 -0700 Subject: [PATCH] [SPARK-9303] Decimal should use java.math.Decimal directly instead of via Scala wrapper --- .../spark/sql/catalyst/expressions/Cast.scala | 2 +- .../org/apache/spark/sql/types/Decimal.scala | 50 ++++++++++--------- 2 files changed, 27 insertions(+), 25 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index c66854d52c50..d4e319845bf6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -192,7 +192,7 @@ case class Cast(child: Expression, dataType: DataType) } private[this] def decimalToTimestamp(d: Decimal): Long = { - (d.toBigDecimal * 1000000L).longValue() + d.toJavaBigDecimal.multiply(java.math.BigDecimal.valueOf(1000000L)).longValue() } private[this] def doubleToTimestamp(d: Double): Any = { if (d.isNaN || d.isInfinite) null else (d * 1000000L).toLong diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala index bc689810bc29..3e99d2999ca2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.types +import java.math.{BigDecimal => JavaBigDecimal} + import org.apache.spark.annotation.DeveloperApi /** @@ -30,7 +32,7 @@ import org.apache.spark.annotation.DeveloperApi final class Decimal extends Ordered[Decimal] with Serializable { import org.apache.spark.sql.types.Decimal.{BIG_DEC_ZERO, MAX_LONG_DIGITS, POW_10, ROUNDING_MODE} - private var decimalVal: BigDecimal = null + private var decimalVal: JavaBigDecimal = null private var longVal: Long = 0L private var _precision: Int = 1 private var _scale: Int = 0 @@ -44,7 +46,7 @@ final class Decimal extends Ordered[Decimal] with Serializable { def set(longVal: Long): Decimal = { if (longVal <= -POW_10(MAX_LONG_DIGITS) || longVal >= POW_10(MAX_LONG_DIGITS)) { // We can't represent this compactly as a long without risking overflow - this.decimalVal = BigDecimal(longVal) + this.decimalVal = new JavaBigDecimal(longVal) this.longVal = 0L } else { this.decimalVal = null @@ -86,7 +88,7 @@ final class Decimal extends Ordered[Decimal] with Serializable { if (precision < 19) { return null // Requested precision is too low to represent this value } - this.decimalVal = BigDecimal(unscaled) + this.decimalVal = new JavaBigDecimal(unscaled) this.longVal = 0L } else { val p = POW_10(math.min(precision, MAX_LONG_DIGITS)) @@ -105,7 +107,7 @@ final class Decimal extends Ordered[Decimal] with Serializable { * Set this Decimal to the given BigDecimal value, with a given precision and scale. */ def set(decimal: BigDecimal, precision: Int, scale: Int): Decimal = { - this.decimalVal = decimal.setScale(scale, ROUNDING_MODE) + this.decimalVal = decimal.setScale(scale, ROUNDING_MODE).underlying() require(decimalVal.precision <= precision, "Overflowed precision") this.longVal = 0L this._precision = precision @@ -117,7 +119,7 @@ final class Decimal extends Ordered[Decimal] with Serializable { * Set this Decimal to the given BigDecimal value, inheriting its precision and scale. */ def set(decimal: BigDecimal): Decimal = { - this.decimalVal = decimal + this.decimalVal = decimal.underlying() this.longVal = 0L this._precision = decimal.precision this._scale = decimal.scale @@ -135,19 +137,19 @@ final class Decimal extends Ordered[Decimal] with Serializable { this } - def toBigDecimal: BigDecimal = { + def toBigDecimal: BigDecimal = BigDecimal(toJavaBigDecimal) + + def toJavaBigDecimal: JavaBigDecimal = { if (decimalVal.ne(null)) { decimalVal } else { - BigDecimal(longVal, _scale) + JavaBigDecimal.valueOf(longVal, _scale) } } - def toJavaBigDecimal: java.math.BigDecimal = toBigDecimal.underlying() - def toUnscaledLong: Long = { if (decimalVal.ne(null)) { - decimalVal.underlying().unscaledValue().longValue() + decimalVal.unscaledValue().longValue() } else { longVal } @@ -164,9 +166,9 @@ final class Decimal extends Ordered[Decimal] with Serializable { } } - def toDouble: Double = toBigDecimal.doubleValue() + def toDouble: Double = toJavaBigDecimal.doubleValue() - def toFloat: Float = toBigDecimal.floatValue() + def toFloat: Float = toJavaBigDecimal.floatValue() def toLong: Long = { if (decimalVal.eq(null)) { @@ -208,7 +210,7 @@ final class Decimal extends Ordered[Decimal] with Serializable { longVal *= POW_10(diff) } else { // Give up on using Longs; switch to BigDecimal, which we'll modify below - decimalVal = BigDecimal(longVal, _scale) + decimalVal = JavaBigDecimal.valueOf(longVal, _scale) } } // In both cases, we will check whether our precision is okay below @@ -217,7 +219,7 @@ final class Decimal extends Ordered[Decimal] with Serializable { if (decimalVal.ne(null)) { // We get here if either we started with a BigDecimal, or we switched to one because we would // have overflowed our Long; in either case we must rescale decimalVal to the new scale. - val newVal = decimalVal.setScale(scale, ROUNDING_MODE) + val newVal = decimalVal.setScale(scale, ROUNDING_MODE.id) if (newVal.precision > precision) { return false } @@ -242,7 +244,7 @@ final class Decimal extends Ordered[Decimal] with Serializable { if (decimalVal.eq(null) && other.decimalVal.eq(null) && _scale == other._scale) { if (longVal < other.longVal) -1 else if (longVal == other.longVal) 0 else 1 } else { - toBigDecimal.compare(other.toBigDecimal) + toJavaBigDecimal.compareTo(other.toJavaBigDecimal) } } @@ -253,27 +255,27 @@ final class Decimal extends Ordered[Decimal] with Serializable { false } - override def hashCode(): Int = toBigDecimal.hashCode() + override def hashCode(): Int = toJavaBigDecimal.hashCode() def isZero: Boolean = if (decimalVal.ne(null)) decimalVal == BIG_DEC_ZERO else longVal == 0 - def + (that: Decimal): Decimal = Decimal(toBigDecimal + that.toBigDecimal) + def + (that: Decimal): Decimal = Decimal(toJavaBigDecimal.add(that.toJavaBigDecimal)) - def - (that: Decimal): Decimal = Decimal(toBigDecimal - that.toBigDecimal) + def - (that: Decimal): Decimal = Decimal(toJavaBigDecimal.subtract(that.toJavaBigDecimal)) - def * (that: Decimal): Decimal = Decimal(toBigDecimal * that.toBigDecimal) + def * (that: Decimal): Decimal = Decimal(toJavaBigDecimal.multiply(that.toJavaBigDecimal)) def / (that: Decimal): Decimal = - if (that.isZero) null else Decimal(toBigDecimal / that.toBigDecimal) + if (that.isZero) null else Decimal(toJavaBigDecimal.divide(that.toJavaBigDecimal)) def % (that: Decimal): Decimal = - if (that.isZero) null else Decimal(toBigDecimal % that.toBigDecimal) + if (that.isZero) null else Decimal(toJavaBigDecimal.remainder(that.toJavaBigDecimal)) def remainder(that: Decimal): Decimal = this % that def unary_- : Decimal = { if (decimalVal.ne(null)) { - Decimal(-decimalVal) + Decimal(decimalVal.negate()) } else { Decimal(-longVal, precision, scale) } @@ -290,7 +292,7 @@ object Decimal { private val POW_10 = Array.tabulate[Long](MAX_LONG_DIGITS + 1)(i => math.pow(10, i).toLong) - private val BIG_DEC_ZERO = BigDecimal(0) + private val BIG_DEC_ZERO: JavaBigDecimal = JavaBigDecimal.valueOf(0) def apply(value: Double): Decimal = new Decimal().set(value) @@ -300,7 +302,7 @@ object Decimal { def apply(value: BigDecimal): Decimal = new Decimal().set(value) - def apply(value: java.math.BigDecimal): Decimal = new Decimal().set(value) + def apply(value: JavaBigDecimal): Decimal = new Decimal().set(value) def apply(value: BigDecimal, precision: Int, scale: Int): Decimal = new Decimal().set(value, precision, scale)