Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -310,11 +310,7 @@ object CatalystTypeConverters {
case d: JavaBigInteger => Decimal(d)
case d: Decimal => d
}
if (decimal.changePrecision(dataType.precision, dataType.scale)) {
decimal
} else {
null
}
decimal.toPrecision(dataType.precision, dataType.scale).orNull
}
override def toScala(catalystValue: Decimal): JavaBigDecimal = {
if (catalystValue == null) null
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -348,6 +348,15 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
if (value.changePrecision(decimalType.precision, decimalType.scale)) value else null
}

/**
* Create new `Decimal` with precision and scale given in `decimalType` (if any),
* returning null if it overflows or creating a new `value` and returning it if successful.
*
*/
private[this] def toPrecision(value: Decimal, decimalType: DecimalType): Decimal =
value.toPrecision(decimalType.precision, decimalType.scale).orNull


private[this] def castToDecimal(from: DataType, target: DecimalType): Any => Any = from match {
case StringType =>
buildCast[UTF8String](_, s => try {
Expand All @@ -356,14 +365,14 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
case _: NumberFormatException => null
})
case BooleanType =>
buildCast[Boolean](_, b => changePrecision(if (b) Decimal.ONE else Decimal.ZERO, target))
buildCast[Boolean](_, b => toPrecision(if (b) Decimal.ONE else Decimal.ZERO, target))
case DateType =>
buildCast[Int](_, d => null) // date can't cast to decimal in Hive
case TimestampType =>
// Note that we lose precision here.
buildCast[Long](_, t => changePrecision(Decimal(timestampToDouble(t)), target))
case dt: DecimalType =>
b => changePrecision(b.asInstanceOf[Decimal].clone(), target)
b => toPrecision(b.asInstanceOf[Decimal], target)
case t: IntegralType =>
b => changePrecision(Decimal(t.integral.asInstanceOf[Integral[Any]].toLong(b)), target)
case x: FractionalType =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,14 +84,8 @@ case class CheckOverflow(child: Expression, dataType: DecimalType) extends Unary

override def nullable: Boolean = true

override def nullSafeEval(input: Any): Any = {
val d = input.asInstanceOf[Decimal].clone()
if (d.changePrecision(dataType.precision, dataType.scale)) {
d
} else {
null
}
}
override def nullSafeEval(input: Any): Any =
input.asInstanceOf[Decimal].toPrecision(dataType.precision, dataType.scale).orNull

override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
nullSafeCodeGen(ctx, ev, eval => {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1024,7 +1024,7 @@ abstract class RoundBase(child: Expression, scale: Expression,
child.dataType match {
case _: DecimalType =>
val decimal = input1.asInstanceOf[Decimal]
if (decimal.changePrecision(decimal.precision, _scale, mode)) decimal else null
decimal.toPrecision(decimal.precision, _scale, mode).orNull
case ByteType =>
BigDecimal(input1.asInstanceOf[Byte]).setScale(_scale, mode).toByte
case ShortType =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import java.lang.{Long => JLong}
import java.math.{BigInteger, MathContext, RoundingMode}

import org.apache.spark.annotation.InterfaceStability
import org.apache.spark.sql.AnalysisException

/**
* A mutable implementation of BigDecimal that can hold a Long if values are small enough.
Expand Down Expand Up @@ -222,6 +223,19 @@ final class Decimal extends Ordered[Decimal] with Serializable {
case java.math.BigDecimal.ROUND_HALF_EVEN => changePrecision(precision, scale, ROUND_HALF_EVEN)
}

/**
* Create new `Decimal` with given precision and scale.
*
* @return `Some(decimal)` if successful or `None` if overflow would occur
*/
private[sql] def toPrecision(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

style:

def xxx(
    para1: xxx,
    para2: xxx): XXX

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fixed

precision: Int,
scale: Int,
roundMode: BigDecimal.RoundingMode.Value = ROUND_HALF_UP): Option[Decimal] = {
val copy = clone()
if (copy.changePrecision(precision, scale, roundMode)) Some(copy) else None
}

/**
* Update precision and scale while keeping our value the same, and return true if successful.
*
Expand Down Expand Up @@ -362,17 +376,15 @@ final class Decimal extends Ordered[Decimal] with Serializable {
def abs: Decimal = if (this.compare(Decimal.ZERO) < 0) this.unary_- else this

def floor: Decimal = if (scale == 0) this else {
val value = this.clone()
value.changePrecision(
DecimalType.bounded(precision - scale + 1, 0).precision, 0, ROUND_FLOOR)
value
val newPrecision = DecimalType.bounded(precision - scale + 1, 0).precision
toPrecision(newPrecision, 0, ROUND_FLOOR).getOrElse(
throw new AnalysisException(s"Overflow when setting precision to $newPrecision"))
}

def ceil: Decimal = if (scale == 0) this else {
val value = this.clone()
value.changePrecision(
DecimalType.bounded(precision - scale + 1, 0).precision, 0, ROUND_CEILING)
value
val newPrecision = DecimalType.bounded(precision - scale + 1, 0).precision
toPrecision(newPrecision, 0, ROUND_CEILING).getOrElse(
throw new AnalysisException(s"Overflow when setting precision to $newPrecision"))
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ class DecimalSuite extends SparkFunSuite with PrivateMethodTester {
assert(Decimal(Long.MaxValue, 100, 0).toUnscaledLong === Long.MaxValue)
}

test("changePrecision() on compact decimal should respect rounding mode") {
test("changePrecision/toPrecision on compact decimal should respect rounding mode") {
Seq(ROUND_FLOOR, ROUND_CEILING, ROUND_HALF_UP, ROUND_HALF_EVEN).foreach { mode =>
Seq("0.4", "0.5", "0.6", "1.0", "1.1", "1.6", "2.5", "5.5").foreach { n =>
Seq("", "-").foreach { sign =>
Expand All @@ -202,6 +202,12 @@ class DecimalSuite extends SparkFunSuite with PrivateMethodTester {
val d = Decimal(unscaled, 8, 1)
assert(d.changePrecision(10, 0, mode))
assert(d.toString === bd.setScale(0, mode).toString(), s"num: $sign$n, mode: $mode")

val copy = d.toPrecision(10, 0, mode).orNull
assert(copy !== null)
assert(d.ne(copy))
assert(d === copy)
assert(copy.toString === bd.setScale(0, mode).toString(), s"num: $sign$n, mode: $mode")
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,18 @@ class MathFunctionsSuite extends QueryTest with SharedSQLContext {
)
}

test("round/bround with data frame from a local Seq of Product") {
val df = spark.createDataFrame(Seq(Tuple1(BigDecimal("5.9")))).toDF("value")
checkAnswer(
df.withColumn("value_rounded", round('value)),
Seq(Row(BigDecimal("5.9"), BigDecimal("6")))
)
checkAnswer(
df.withColumn("value_brounded", bround('value)),
Seq(Row(BigDecimal("5.9"), BigDecimal("6")))
)
}

test("exp") {
testOneToOneMathFunction(exp, math.exp)
}
Expand Down