Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -310,7 +310,7 @@ object CatalystTypeConverters {
case d: JavaBigInteger => Decimal(d)
case d: Decimal => d
}
decimal.toPrecision(dataType.precision, dataType.scale).orNull
decimal.toPrecision(dataType.precision, dataType.scale)
}
override def toScala(catalystValue: Decimal): JavaBigDecimal = {
if (catalystValue == null) null
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -387,10 +387,9 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String
/**
* Create new `Decimal` with precision and scale given in `decimalType` (if any),
* returning null if it overflows or creating a new `value` and returning it if successful.
*
*/
private[this] def toPrecision(value: Decimal, decimalType: DecimalType): Decimal =
value.toPrecision(decimalType.precision, decimalType.scale).orNull
value.toPrecision(decimalType.precision, decimalType.scale)


private[this] def castToDecimal(from: DataType, target: DecimalType): Any => Any = from match {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ case class CheckOverflow(child: Expression, dataType: DecimalType) extends Unary
override def nullable: Boolean = true

override def nullSafeEval(input: Any): Any =
input.asInstanceOf[Decimal].toPrecision(dataType.precision, dataType.scale).orNull
input.asInstanceOf[Decimal].toPrecision(dataType.precision, dataType.scale)

override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
nullSafeCodeGen(ctx, ev, eval => {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1044,7 +1044,7 @@ abstract class RoundBase(child: Expression, scale: Expression,
dataType match {
case DecimalType.Fixed(_, s) =>
val decimal = input1.asInstanceOf[Decimal]
decimal.toPrecision(decimal.precision, s, mode).orNull
decimal.toPrecision(decimal.precision, s, mode)
case ByteType =>
BigDecimal(input1.asInstanceOf[Byte]).setScale(_scale, mode).toByte
case ShortType =>
Expand Down Expand Up @@ -1076,12 +1076,8 @@ abstract class RoundBase(child: Expression, scale: Expression,
val evaluationCode = dataType match {
case DecimalType.Fixed(_, s) =>
s"""
if (${ce.value}.changePrecision(${ce.value}.precision(), ${s},
java.math.BigDecimal.${modeStr})) {
${ev.value} = ${ce.value};
} else {
${ev.isNull} = true;
}"""
${ev.value} = ${ce.value}.toPrecision(${ce.value}.precision(), $s, Decimal.$modeStr());
${ev.isNull} = ${ev.value} == null;"""
case ByteType =>
if (_scale < 0) {
s"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -234,31 +234,28 @@ final class Decimal extends Ordered[Decimal] with Serializable {
changePrecision(precision, scale, ROUND_HALF_UP)
}

def changePrecision(precision: Int, scale: Int, mode: Int): Boolean = mode match {
case java.math.BigDecimal.ROUND_HALF_UP => changePrecision(precision, scale, ROUND_HALF_UP)
case java.math.BigDecimal.ROUND_HALF_EVEN => changePrecision(precision, scale, ROUND_HALF_EVEN)
}

/**
* Create new `Decimal` with given precision and scale.
*
* @return `Some(decimal)` if successful or `None` if overflow would occur
* @return a non-null `Decimal` value if successful or `null` if overflow would occur.
*/
private[sql] def toPrecision(
precision: Int,
scale: Int,
roundMode: BigDecimal.RoundingMode.Value = ROUND_HALF_UP): Option[Decimal] = {
roundMode: BigDecimal.RoundingMode.Value = ROUND_HALF_UP): Decimal = {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Option is hard to use in java code(the codegen path), so I change the return type to nullable Decimal.

val copy = clone()
if (copy.changePrecision(precision, scale, roundMode)) Some(copy) else None
if (copy.changePrecision(precision, scale, roundMode)) copy else null
}

/**
* Update precision and scale while keeping our value the same, and return true if successful.
*
* @return true if successful, false if overflow would occur
*/
private[sql] def changePrecision(precision: Int, scale: Int,
roundMode: BigDecimal.RoundingMode.Value): Boolean = {
private[sql] def changePrecision(
precision: Int,
scale: Int,
roundMode: BigDecimal.RoundingMode.Value): Boolean = {
// fast path for UnsafeProjection
if (precision == this.precision && scale == this.scale) {
return true
Expand Down Expand Up @@ -393,14 +390,20 @@ final class Decimal extends Ordered[Decimal] with Serializable {

def floor: Decimal = if (scale == 0) this else {
val newPrecision = DecimalType.bounded(precision - scale + 1, 0).precision
toPrecision(newPrecision, 0, ROUND_FLOOR).getOrElse(
throw new AnalysisException(s"Overflow when setting precision to $newPrecision"))
val res = toPrecision(newPrecision, 0, ROUND_FLOOR)
if (res == null) {
throw new AnalysisException(s"Overflow when setting precision to $newPrecision")
}
res
}

def ceil: Decimal = if (scale == 0) this else {
val newPrecision = DecimalType.bounded(precision - scale + 1, 0).precision
toPrecision(newPrecision, 0, ROUND_CEILING).getOrElse(
throw new AnalysisException(s"Overflow when setting precision to $newPrecision"))
val res = toPrecision(newPrecision, 0, ROUND_CEILING)
if (res == null) {
throw new AnalysisException(s"Overflow when setting precision to $newPrecision")
}
res
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@ class DecimalSuite extends SparkFunSuite with PrivateMethodTester {
assert(d.changePrecision(10, 0, mode))
assert(d.toString === bd.setScale(0, mode).toString(), s"num: $sign$n, mode: $mode")

val copy = d.toPrecision(10, 0, mode).orNull
val copy = d.toPrecision(10, 0, mode)
assert(copy !== null)
assert(d.ne(copy))
assert(d === copy)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -258,6 +258,18 @@ class MathFunctionsSuite extends QueryTest with SharedSQLContext {
)
}

test("round/bround with table columns") {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is an end-to-end test for code-gen path. Could we add a unit test case in MathExpressionsSuite?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

All the tests in MathExpressionsSuite are testing about literals, I think it's better to improve MathExpressionsSuite for attributes in a new PR, instead of doing it in a folllow-up PR. BTW the original PR didn't add test in MathExpressionsSuite either.

withTable("t") {
Seq(BigDecimal("5.9")).toDF("i").write.saveAsTable("t")
checkAnswer(
sql("select i, round(i) from t"),
Seq(Row(BigDecimal("5.9"), BigDecimal("6"))))
checkAnswer(
sql("select i, bround(i) from t"),
Seq(Row(BigDecimal("5.9"), BigDecimal("6"))))
}
}

test("exp") {
testOneToOneMathFunction(exp, math.exp)
}
Expand Down