Skip to content

Commit 31dfe7c

Browse files
committed
refactor round to make it readable
1 parent 8c7a949 commit 31dfe7c

File tree

3 files changed

+75
-157
lines changed

3 files changed

+75
-157
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala

Lines changed: 72 additions & 92 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@ import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
2323
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckSuccess, TypeCheckFailure}
2424
import org.apache.spark.sql.catalyst.expressions.codegen._
2525
import org.apache.spark.sql.catalyst.InternalRow
26-
import org.apache.spark.sql.catalyst.util.BigDecimalConverter
2726
import org.apache.spark.sql.types._
2827
import org.apache.spark.unsafe.types.UTF8String
2928

@@ -524,144 +523,125 @@ case class Logarithm(left: Expression, right: Expression)
524523
}
525524
}
526525

527-
case class Round(child: Expression, scale: Expression) extends Expression with ExpectsInputTypes {
526+
case class Round(child: Expression, scale: Expression)
527+
extends BinaryExpression with ExpectsInputTypes {
528528

529-
def this(child: Expression) = {
530-
this(child, Literal(0))
531-
}
529+
import BigDecimal.RoundingMode.HALF_UP
530+
531+
def this(child: Expression) = this(child, Literal(0))
532+
533+
override def left: Expression = child
534+
override def right: Expression = scale
532535

533536
override def children: Seq[Expression] = Seq(child, scale)
534537

538+
// round of Decimal would eval to null if it fails to `changePrecision`
535539
override def nullable: Boolean = true
536540

537541
override def foldable: Boolean = child.foldable
538542

539543
override lazy val dataType: DataType = child.dataType match {
540-
case DecimalType.Fixed(p, s) => DecimalType(p, _scale)
541-
case t => t
542-
}
544+
case DecimalType.Fixed(p, s) => DecimalType(p, _scale)
545+
case t => t
546+
}
543547

544-
override def inputTypes: Seq[AbstractDataType] = Seq(
545-
// rely on precedence to implicit cast String into Double
546-
TypeCollection(DecimalType, DoubleType, FloatType, LongType, IntegerType, ShortType, ByteType),
547-
TypeCollection(LongType, IntegerType, ShortType, ByteType))
548+
override def inputTypes: Seq[AbstractDataType] = Seq(NumericType, IntegerType)
548549

549550
override def checkInputDataTypes(): TypeCheckResult = {
550-
child.dataType match {
551-
case _: NumericType => // satisfy requirement
552-
case dt =>
553-
return TypeCheckFailure(s"Only numeric type is allowed for ROUND function, got $dt")
554-
}
555-
scale match {
556-
case Literal(value, LongType) =>
557-
if (value.asInstanceOf[Long] < Int.MinValue || value.asInstanceOf[Long] > Int.MaxValue) {
558-
return TypeCheckFailure("ROUND scale argument out of allowed range")
559-
}
560-
case _ =>
561-
if (scale.dataType.isInstanceOf[IntegralType] && scale.foldable) {
562-
// TODO: How to check out of range for foldable LongType Expression
563-
// satisfy requirement
551+
super.checkInputDataTypes() match {
552+
case TypeCheckSuccess =>
553+
if (scale.foldable) {
554+
TypeCheckSuccess
564555
} else {
565-
return TypeCheckFailure("Only foldable Integral Expression " +
566-
s"is allowed for ROUND scale arguments, got ${child.dataType}")
556+
TypeCheckFailure("Only foldable Expression is allowed for scale arguments")
567557
}
558+
case f => f
568559
}
569-
TypeCheckSuccess
570560
}
571561

572562
private lazy val scaleV = scale.eval(EmptyRow)
573563
private lazy val _scale = if (scaleV != null) scaleV.asInstanceOf[Int] else 0
574564

575-
override def eval(input: InternalRow): Any = {
576-
val evalE = child.eval(input)
577-
if (evalE == null || scaleV == null) return null
578-
round(evalE)
579-
}
580-
581-
private lazy val round: (Any) => (Any) = typedRound(child.dataType)
582-
583-
// Using dataType info to find an appropriate round method
584-
private def typedRound(dt: DataType)(x: Any): Any = {
585-
dt match {
565+
protected override def nullSafeEval(input1: Any, input2: Any): Any = {
566+
child.dataType match {
586567
case _: DecimalType =>
587-
val decimal = x.asInstanceOf[Decimal]
568+
val decimal = input1.asInstanceOf[Decimal]
588569
if (decimal.changePrecision(decimal.precision, _scale)) decimal else null
589570
case ByteType =>
590-
numericRound(x.asInstanceOf[Byte], _scale)
571+
BigDecimal(input1.asInstanceOf[Byte]).setScale(_scale, HALF_UP).toByte
591572
case ShortType =>
592-
numericRound(x.asInstanceOf[Short], _scale)
573+
BigDecimal(input1.asInstanceOf[Short]).setScale(_scale, HALF_UP).toShort
593574
case IntegerType =>
594-
numericRound(x.asInstanceOf[Int], _scale)
575+
BigDecimal(input1.asInstanceOf[Int]).setScale(_scale, HALF_UP).toInt
595576
case LongType =>
596-
numericRound(x.asInstanceOf[Long], _scale)
577+
BigDecimal(input1.asInstanceOf[Long]).setScale(_scale, HALF_UP).toLong
597578
case FloatType =>
598-
numericRound(x.asInstanceOf[Float], _scale)
579+
val f = input1.asInstanceOf[Float]
580+
if (f.isNaN || f.isInfinite) {
581+
f
582+
} else {
583+
BigDecimal(f).setScale(_scale, HALF_UP).toFloat
584+
}
599585
case DoubleType =>
600-
numericRound(x.asInstanceOf[Double], _scale)
601-
}
602-
}
603-
604-
private def numericRound[T](input: T, scale: Int)(implicit bdc: BigDecimalConverter[T]): T = {
605-
input match {
606-
case f: Float if (f.isNaN || f.isInfinite) => return input
607-
case d: Double if (d.isNaN || d.isInfinite) => return input
608-
case _ =>
586+
val d = input1.asInstanceOf[Double]
587+
if (d.isNaN || d.isInfinite) {
588+
d
589+
} else {
590+
BigDecimal(d).setScale(_scale, HALF_UP).toDouble
591+
}
609592
}
610-
bdc.fromBigDecimal(bdc.toBigDecimal(input).setScale(scale, BigDecimal.RoundingMode.HALF_UP))
611593
}
612594

613595
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
614596
val ce = child.gen(ctx)
615597

616-
def round(primitive: String, integral: Boolean): String = {
617-
val (p1, p2) = if (integral) ("new", "") else ("", ".valueOf")
618-
s"""
619-
${ev.primitive} = $p1 java.math.BigDecimal$p2(${primitive}).
620-
setScale(${_scale}, java.math.BigDecimal.ROUND_HALF_UP)"""
621-
}
622-
623-
def fractionalCheck(primitive: String, function: String): String = {
624-
s"""
625-
if (Double.isNaN(${primitive}) || Double.isInfinite(${primitive})){
626-
${ev.primitive} = ${primitive};
627-
} else {
628-
${round(primitive, false)}.${function};
629-
}"""
630-
}
631-
632-
def decimalRound(): String = {
633-
s"""
634-
if (${ce.primitive}.changePrecision(${ce.primitive}.precision(), ${_scale})) {
635-
${ev.primitive} = ${ce.primitive};
636-
} else {
637-
${ev.isNull} = true;
638-
}
639-
"""
640-
}
641-
642-
val roundCode = child.dataType match {
643-
case NullType => ";"
598+
val evaluationCode = child.dataType match {
644599
case _: DecimalType =>
645-
decimalRound()
600+
s"""
601+
if (${ce.primitive}.changePrecision(${ce.primitive}.precision(), ${_scale})) {
602+
${ev.primitive} = ${ce.primitive};
603+
} else {
604+
${ev.isNull} = true;
605+
}"""
646606
case ByteType =>
647-
round(ce.primitive, true) + ".byteValue();"
607+
s"""
608+
${ev.primitive} = new java.math.BigDecimal(${ce.primitive}).
609+
setScale(${_scale}, java.math.BigDecimal.ROUND_HALF_UP).byteValue();"""
648610
case ShortType =>
649-
round(ce.primitive, true) + ".shortValue();"
611+
s"""
612+
${ev.primitive} = new java.math.BigDecimal(${ce.primitive}).
613+
setScale(${_scale}, java.math.BigDecimal.ROUND_HALF_UP).shortValue();"""
650614
case IntegerType =>
651-
round(ce.primitive, true) + ".intValue();"
615+
s"""
616+
${ev.primitive} = new java.math.BigDecimal(${ce.primitive}).
617+
setScale(${_scale}, java.math.BigDecimal.ROUND_HALF_UP).intValue();"""
652618
case LongType =>
653-
round(ce.primitive, true) + ".longValue();"
619+
s"""
620+
${ev.primitive} = new java.math.BigDecimal(${ce.primitive}).
621+
setScale(${_scale}, java.math.BigDecimal.ROUND_HALF_UP).longValue();"""
654622
case FloatType =>
655-
fractionalCheck(ce.primitive, "floatValue()")
623+
s"""
624+
if (Float.isNaN(${ce.primitive}) || Float.isInfinite(${ce.primitive})){
625+
${ev.primitive} = ${ce.primitive};
626+
} else {
627+
${ev.primitive} = java.math.BigDecimal.valueOf(${ce.primitive}).
628+
setScale(${_scale}, java.math.BigDecimal.ROUND_HALF_UP).floatValue();
629+
}"""
656630
case DoubleType =>
657-
fractionalCheck(ce.primitive, "doubleValue()")
631+
s"""
632+
if (Double.isNaN(${ce.primitive}) || Double.isInfinite(${ce.primitive})){
633+
${ev.primitive} = ${ce.primitive};
634+
} else {
635+
${ev.primitive} = java.math.BigDecimal.valueOf(${ce.primitive}).
636+
setScale(${_scale}, java.math.BigDecimal.ROUND_HALF_UP).doubleValue();
637+
}"""
658638
}
659639

660640
ce.code + s"""
661641
boolean ${ev.isNull} = ${ce.isNull} || ${scaleV == null};
662642
${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
663643
if (!${ev.isNull}) {
664-
${roundCode}
644+
${evaluationCode}
665645
}
666646
"""
667647
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/BigDecimalConverter.scala

Lines changed: 0 additions & 60 deletions
This file was deleted.

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -181,13 +181,11 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite {
181181

182182
test("check types for ROUND") {
183183
assertErrorWithImplicitCast(Round(Literal(null), 'booleanField),
184-
"Only foldable Integral Expression is allowed for ROUND scale arguments")
184+
"data type mismatch: argument 2 is expected to be of type int")
185185
assertErrorWithImplicitCast(Round(Literal(null), 'complexField),
186-
"Only foldable Integral Expression is allowed for ROUND scale arguments")
186+
"data type mismatch: argument 2 is expected to be of type int")
187187
assertSuccess(Round(Literal(null), Literal(null)))
188188
assertError(Round('booleanField, 'intField),
189-
"Only numeric type is allowed for ROUND function")
190-
assertErrorWithImplicitCast(Round(Literal(null), Literal(1L + Int.MaxValue)),
191-
"ROUND scale argument out of allowed range")
189+
"data type mismatch: argument 1 is expected to be of type numeric")
192190
}
193191
}

0 commit comments

Comments
 (0)