Skip to content

Commit 1b87540

Browse files
committed
modify checkInputDataTypes using foldable
1 parent 5486b2d commit 1b87540

File tree

2 files changed

+18
-18
lines changed

2 files changed

+18
-18
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -558,8 +558,8 @@ case class Round(child: Expression, scale: Expression) extends Expression {
558558
return TypeCheckFailure("ROUND scale argument out of allowed range")
559559
}
560560
case Literal(_, _: IntegralType) | Literal(_, NullType) => // satisfy requirement
561-
case child =>
562-
if (child.find { case _: AttributeReference => true; case _ => false } != None) {
561+
case _ =>
562+
if (!scale.foldable) {
563563
return TypeCheckFailure("Only Integral Literal or Null Literal " +
564564
s"are allowed for ROUND scale arguments, got ${child.dataType}")
565565
}
@@ -595,6 +595,21 @@ case class Round(child: Expression, scale: Expression) extends Expression {
595595
}
596596
}
597597

598+
private def round[T](input: T, scale: Int)(implicit bdc: BigDecimalConverter[T]): T = {
599+
input match {
600+
case f: Float if (f.isNaN || f.isInfinite) => return input
601+
case d: Double if (d.isNaN || d.isInfinite) => return input
602+
case _ =>
603+
}
604+
bdc.fromBigDecimal(bdc.toBigDecimal(input).setScale(scale, BigDecimal.RoundingMode.HALF_UP))
605+
}
606+
607+
private def round(input: String, scale: Int): Any = {
608+
try round(input.toDouble, scale) catch {
609+
case _ : NumberFormatException => null
610+
}
611+
}
612+
598613
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
599614
val ce = child.gen(ctx)
600615

@@ -672,19 +687,4 @@ case class Round(child: Expression, scale: Expression) extends Expression {
672687
}
673688
"""
674689
}
675-
676-
private def round[T](input: T, scale: Int)(implicit bdc: BigDecimalConverter[T]): T = {
677-
input match {
678-
case f: Float if (f.isNaN || f.isInfinite) => return input
679-
case d: Double if (d.isNaN || d.isInfinite) => return input
680-
case _ =>
681-
}
682-
bdc.fromBigDecimal(bdc.toBigDecimal(input).setScale(scale, BigDecimal.RoundingMode.HALF_UP))
683-
}
684-
685-
private def round(input: String, scale: Int): Any = {
686-
try round(input.toDouble, scale) catch {
687-
case _ : NumberFormatException => null
688-
}
689-
}
690690
}

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -339,7 +339,7 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
339339
create_row(null))
340340
}
341341

342-
test("round test") {
342+
test("round") {
343343
val domain = -16 to 16
344344
val doublePi = math.Pi
345345
val stringPi = "3.141592653589793"

0 commit comments

Comments
 (0)