Skip to content

Commit e6f44c4

Browse files
committed
refactor eval and genCode
1 parent 1b87540 commit e6f44c4

File tree

2 files changed

+50
-52
lines changed

2 files changed

+50
-52
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala

Lines changed: 48 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -404,11 +404,7 @@ case class Atan2(left: Expression, right: Expression)
404404
case class Pow(left: Expression, right: Expression)
405405
extends BinaryMathExpression(math.pow, "POWER") {
406406
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
407-
defineCodeGen(ctx, ev, (c1, c2) => s"java.lang.Math.pow($c1, $c2)") + s"""
408-
if (Double.valueOf(${ev.primitive}).isNaN()) {
409-
${ev.isNull} = true;
410-
}
411-
"""
407+
defineCodeGen(ctx, ev, (c1, c2) => s"java.lang.Math.pow($c1, $c2)")
412408
}
413409
}
414410

@@ -530,20 +526,20 @@ case class Round(child: Expression, scale: Expression) extends Expression {
530526
this(child, Literal(0))
531527
}
532528

533-
def children: Seq[Expression] = Seq(child, scale)
529+
override def children: Seq[Expression] = Seq(child, scale)
530+
531+
override def nullable: Boolean = true
534532

535-
def nullable: Boolean = true
533+
override def foldable: Boolean = child.foldable
536534

537-
private lazy val scaleV = scale.asInstanceOf[Literal].value
535+
private lazy val scaleV = scale.eval(EmptyRow)
538536
private lazy val _scale = if (scaleV != null) scaleV.asInstanceOf[Int] else 0
539537

540-
override lazy val dataType: DataType = {
541-
child.dataType match {
538+
override lazy val dataType: DataType = child.dataType match {
542539
case StringType | BinaryType => DoubleType
543540
case DecimalType.Fixed(p, s) => DecimalType(p, _scale)
544541
case t => t
545542
}
546-
}
547543

548544
override def checkInputDataTypes(): TypeCheckResult = {
549545
child.dataType match {
@@ -557,41 +553,42 @@ case class Round(child: Expression, scale: Expression) extends Expression {
557553
if (value.asInstanceOf[Long] < Int.MinValue || value.asInstanceOf[Long] > Int.MaxValue) {
558554
return TypeCheckFailure("ROUND scale argument out of allowed range")
559555
}
560-
case Literal(_, _: IntegralType) | Literal(_, NullType) => // satisfy requirement
561556
case _ =>
562-
if (!scale.foldable) {
563-
return TypeCheckFailure("Only Integral Literal or Null Literal " +
564-
s"are allowed for ROUND scale arguments, got ${child.dataType}")
557+
if ((scale.dataType.isInstanceOf[IntegralType] || scale.dataType == NullType) &&
558+
scale.foldable) {
559+
// TODO: foldable LongType is not checked for out of range
560+
// satisfy requirement
561+
} else {
562+
return TypeCheckFailure("Only Integral or Null foldable Expression " +
563+
s"is allowed for ROUND scale arguments, got ${child.dataType}")
565564
}
566565
}
567566
TypeCheckSuccess
568567
}
569568

570-
def eval(input: InternalRow): Any = {
571-
val evalE = child.eval(input)
569+
private lazy val rounding: (Any) => (Any) = roundGen(child.dataType)
572570

573-
if (evalE == null || scaleV == null) return null
574-
575-
child.dataType match {
571+
def roundGen(dt: DataType)(x: Any): Any = {
572+
dt match {
576573
case _: DecimalType =>
577-
val decimal = evalE.asInstanceOf[Decimal]
574+
val decimal = x.asInstanceOf[Decimal]
578575
if (decimal.changePrecision(decimal.precision, _scale)) decimal else null
579576
case ByteType =>
580-
round(evalE.asInstanceOf[Byte], _scale)
577+
round(x.asInstanceOf[Byte], _scale)
581578
case ShortType =>
582-
round(evalE.asInstanceOf[Short], _scale)
579+
round(x.asInstanceOf[Short], _scale)
583580
case IntegerType =>
584-
round(evalE.asInstanceOf[Int], _scale)
581+
round(x.asInstanceOf[Int], _scale)
585582
case LongType =>
586-
round(evalE.asInstanceOf[Long], _scale)
583+
round(x.asInstanceOf[Long], _scale)
587584
case FloatType =>
588-
round(evalE.asInstanceOf[Float], _scale)
585+
round(x.asInstanceOf[Float], _scale)
589586
case DoubleType =>
590-
round(evalE.asInstanceOf[Double], _scale)
587+
round(x.asInstanceOf[Double], _scale)
591588
case StringType =>
592-
round(evalE.asInstanceOf[UTF8String].toString, _scale)
589+
round(x.asInstanceOf[UTF8String].toString, _scale)
593590
case BinaryType =>
594-
round(UTF8String.fromBytes(evalE.asInstanceOf[Array[Byte]]).toString, _scale)
591+
round(UTF8String.fromBytes(x.asInstanceOf[Array[Byte]]).toString, _scale)
595592
}
596593
}
597594

@@ -606,35 +603,36 @@ case class Round(child: Expression, scale: Expression) extends Expression {
606603

607604
private def round(input: String, scale: Int): Any = {
608605
try round(input.toDouble, scale) catch {
609-
case _ : NumberFormatException => null
606+
case _: NumberFormatException => null
610607
}
611608
}
612609

610+
def eval(input: InternalRow): Any = {
611+
val evalE = child.eval(input)
612+
if (evalE == null || scaleV == null) return null
613+
rounding(evalE)
614+
}
615+
613616
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
614617
val ce = child.gen(ctx)
615618

616-
def integralRound(primitive: String): String = {
617-
s"""
618-
${ev.primitive} = new java.math.BigDecimal(${primitive}).
619-
setScale(${_scale}, java.math.BigDecimal.ROUND_HALF_UP)"""
620-
}
621-
622-
def fractionalRound(primitive: String): String = {
619+
def round(primitive: String, integral: Boolean): String = {
620+
val (p1, p2) = if (integral) ("new", "") else ("", ".valueOf")
623621
s"""
624-
${ev.primitive} = java.math.BigDecimal.valueOf(${primitive}).
622+
${ev.primitive} = $p1 java.math.BigDecimal$p2(${primitive}).
625623
setScale(${_scale}, java.math.BigDecimal.ROUND_HALF_UP)"""
626624
}
627625

628-
def check(primitive: String, function: String): String = {
626+
def fractionalCheck(primitive: String, function: String): String = {
629627
s"""
630628
if (Double.isNaN(${primitive}) || Double.isInfinite(${primitive})){
631629
${ev.primitive} = ${primitive};
632630
} else {
633-
${fractionalRound(primitive)}.${function};
631+
${round(primitive, false)}.${function};
634632
}"""
635633
}
636634

637-
def convert(primitive: String): String = {
635+
def stringLikeConvert(primitive: String): String = {
638636
val dName = ctx.freshName("converter")
639637
s"""
640638
Double $dName = 0.0;
@@ -643,7 +641,7 @@ case class Round(child: Expression, scale: Expression) extends Expression {
643641
} catch (NumberFormatException e) {
644642
${ev.isNull} = true;
645643
}
646-
${check(dName, "doubleValue()")}
644+
${fractionalCheck(dName, "doubleValue()")}
647645
"""
648646
}
649647

@@ -662,21 +660,21 @@ case class Round(child: Expression, scale: Expression) extends Expression {
662660
case _: DecimalType =>
663661
decimalRound()
664662
case ByteType =>
665-
integralRound(ce.primitive) + ".byteValue();"
663+
round(ce.primitive, true) + ".byteValue();"
666664
case ShortType =>
667-
integralRound(ce.primitive) + ".shortValue();"
665+
round(ce.primitive, true) + ".shortValue();"
668666
case IntegerType =>
669-
integralRound(ce.primitive) + ".intValue();"
667+
round(ce.primitive, true) + ".intValue();"
670668
case LongType =>
671-
integralRound(ce.primitive) + ".longValue();"
669+
round(ce.primitive, true) + ".longValue();"
672670
case FloatType =>
673-
check(ce.primitive, "floatValue()")
671+
fractionalCheck(ce.primitive, "floatValue()")
674672
case DoubleType =>
675-
check(ce.primitive, "doubleValue()")
673+
fractionalCheck(ce.primitive, "doubleValue()")
676674
case StringType =>
677-
convert(ce.primitive)
675+
stringLikeConvert(ce.primitive)
678676
case BinaryType =>
679-
convert(s"${ctx.stringType}.fromBytes(${ce.primitive})")
677+
stringLikeConvert(s"${ctx.stringType}.fromBytes(${ce.primitive})")
680678
}
681679

682680
ce.code + s"""

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -174,9 +174,9 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite {
174174

175175
test("check types for ROUND") {
176176
assertError(Round(Literal(null), 'booleanField),
177-
"Only Integral Literal or Null Literal are allowed for ROUND scale argument")
177+
"Only Integral or Null foldable Expression is allowed for ROUND scale argument")
178178
assertError(Round(Literal(null), 'complexField),
179-
"Only Integral Literal or Null Literal are allowed for ROUND scale argument")
179+
"Only Integral or Null foldable Expression is allowed for ROUND scale argument")
180180
assertSuccess(Round(Literal(null), Literal(null)))
181181
assertError(Round('booleanField, 'intField),
182182
"Only numeric, string or binary data types are allowed for ROUND function")

0 commit comments

Comments
 (0)