@@ -523,6 +523,20 @@ case class Logarithm(left: Expression, right: Expression)
523523 }
524524}
525525
526+ /**
527+ * Round the `child`'s result to `scale` decimal place when `scale` >= 0
528+ * or round at integral part when `scale` < 0.
529+ * For example, round(31.415, 2) would eval to 31.42 and round(31.415, -1) would eval to 30.
530+ *
531+ * Child of IntegralType would eval to itself when `scale` >= 0.
532+ * Child of FractionalType whose value is NaN or Infinite would always eval to itself.
533+ *
534+ * Round's dataType would always equal to `child`'s dataType except for [[DecimalType.Fixed ]],
535+ * which leads to scale update in DecimalType's [[PrecisionInfo ]]
536+ *
537+ * @param child expr to be round, all [[NumericType ]] is allowed as Input
538+ * @param scale new scale to be round to, this should be a constant int at runtime
539+ */
526540case class Round (child : Expression , scale : Expression )
527541 extends BinaryExpression with ExpectsInputTypes {
528542
@@ -559,10 +573,27 @@ case class Round(child: Expression, scale: Expression)
559573 }
560574 }
561575
562- private lazy val scaleV = scale.eval(EmptyRow )
563- private lazy val _scale = if (scaleV != null ) scaleV.asInstanceOf [Int ] else 0
576+ // Avoid repeated evaluation since `scale` is a constant int,
577+ // avoid unnecessary `child` evaluation in both codegen and non-codegen eval
578+ // by checking if scaleV == null as well.
579+ private lazy val scaleV : Any = scale.eval(EmptyRow )
580+ private lazy val _scale : Int = scaleV.asInstanceOf [Int ]
564581
565- protected override def nullSafeEval (input1 : Any , input2 : Any ): Any = {
582+ override def eval (input : InternalRow ): Any = {
583+ if (scaleV == null ) { // if scale is null, no need to eval its child at all
584+ null
585+ } else {
586+ val evalE = child.eval(input)
587+ if (evalE == null ) {
588+ null
589+ } else {
590+ nullSafeEval(evalE)
591+ }
592+ }
593+ }
594+
595+ // not overriding since _scale is a constant int at runtime
596+ def nullSafeEval (input1 : Any ): Any = {
566597 child.dataType match {
567598 case _ : DecimalType =>
568599 val decimal = input1.asInstanceOf [Decimal ]
@@ -604,45 +635,89 @@ case class Round(child: Expression, scale: Expression)
604635 ${ev.isNull} = true;
605636 } """
606637 case ByteType =>
607- s """
608- ${ev.primitive} = new java.math.BigDecimal( ${ce.primitive}).
609- setScale( ${_scale}, java.math.BigDecimal.ROUND_HALF_UP).byteValue(); """
638+ if (_scale < 0 ) {
639+ s """
640+ ${ev.primitive} = new java.math.BigDecimal( ${ce.primitive}).
641+ setScale( ${_scale}, java.math.BigDecimal.ROUND_HALF_UP).byteValue(); """
642+ } else {
643+ s " ${ev.primitive} = ${ce.primitive}; "
644+ }
610645 case ShortType =>
611- s """
612- ${ev.primitive} = new java.math.BigDecimal( ${ce.primitive}).
613- setScale( ${_scale}, java.math.BigDecimal.ROUND_HALF_UP).shortValue(); """
646+ if (_scale < 0 ) {
647+ s """
648+ ${ev.primitive} = new java.math.BigDecimal( ${ce.primitive}).
649+ setScale( ${_scale}, java.math.BigDecimal.ROUND_HALF_UP).shortValue(); """
650+ } else {
651+ s " ${ev.primitive} = ${ce.primitive}; "
652+ }
614653 case IntegerType =>
615- s """
616- ${ev.primitive} = new java.math.BigDecimal( ${ce.primitive}).
617- setScale( ${_scale}, java.math.BigDecimal.ROUND_HALF_UP).intValue(); """
654+ if (_scale < 0 ) {
655+ s """
656+ ${ev.primitive} = new java.math.BigDecimal( ${ce.primitive}).
657+ setScale( ${_scale}, java.math.BigDecimal.ROUND_HALF_UP).intValue(); """
658+ } else {
659+ s " ${ev.primitive} = ${ce.primitive}; "
660+ }
618661 case LongType =>
619- s """
620- ${ev.primitive} = new java.math.BigDecimal( ${ce.primitive}).
621- setScale( ${_scale}, java.math.BigDecimal.ROUND_HALF_UP).longValue(); """
622- case FloatType =>
623- s """
624- if (Float.isNaN( ${ce.primitive}) || Float.isInfinite( ${ce.primitive})){
625- ${ev.primitive} = ${ce.primitive};
662+ if (_scale < 0 ) {
663+ s """
664+ ${ev.primitive} = new java.math.BigDecimal( ${ce.primitive}).
665+ setScale( ${_scale}, java.math.BigDecimal.ROUND_HALF_UP).longValue(); """
626666 } else {
627- ${ev.primitive} = java.math.BigDecimal.valueOf( ${ce.primitive}).
628- setScale( ${_scale}, java.math.BigDecimal.ROUND_HALF_UP).floatValue();
629- } """
630- case DoubleType =>
631- s """
632- if (Double.isNaN( ${ce.primitive}) || Double.isInfinite( ${ce.primitive})){
633- ${ev.primitive} = ${ce.primitive};
667+ s " ${ev.primitive} = ${ce.primitive}; "
668+ }
669+ case FloatType => // if child eval to NaN or Infinity, just return it.
670+ if (_scale == 0 ) {
671+ s """
672+ if (Float.isNaN( ${ce.primitive}) || Float.isInfinite( ${ce.primitive})){
673+ ${ev.primitive} = ${ce.primitive};
674+ } else {
675+ ${ev.primitive} = Math.round( ${ce.primitive});
676+ } """
634677 } else {
635- ${ev.primitive} = java.math.BigDecimal.valueOf( ${ce.primitive}).
636- setScale( ${_scale}, java.math.BigDecimal.ROUND_HALF_UP).doubleValue();
637- } """
678+ s """
679+ if (Float.isNaN( ${ce.primitive}) || Float.isInfinite( ${ce.primitive})){
680+ ${ev.primitive} = ${ce.primitive};
681+ } else {
682+ ${ev.primitive} = java.math.BigDecimal.valueOf( ${ce.primitive}).
683+ setScale( ${_scale}, java.math.BigDecimal.ROUND_HALF_UP).floatValue();
684+ } """
685+ }
686+ case DoubleType => // if child eval to NaN or Infinity, just return it.
687+ if (_scale == 0 ) {
688+ s """
689+ if (Double.isNaN( ${ce.primitive}) || Double.isInfinite( ${ce.primitive})){
690+ ${ev.primitive} = ${ce.primitive};
691+ } else {
692+ ${ev.primitive} = Math.round( ${ce.primitive});
693+ } """
694+ } else {
695+ s """
696+ if (Double.isNaN( ${ce.primitive}) || Double.isInfinite( ${ce.primitive})){
697+ ${ev.primitive} = ${ce.primitive};
698+ } else {
699+ ${ev.primitive} = java.math.BigDecimal.valueOf( ${ce.primitive}).
700+ setScale( ${_scale}, java.math.BigDecimal.ROUND_HALF_UP).doubleValue();
701+ } """
702+ }
638703 }
639704
640- ce.code + s """
641- boolean ${ev.isNull} = ${ce.isNull} || ${scaleV == null };
642- ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
643- if (! ${ev.isNull}) {
644- ${evaluationCode}
645- }
705+ if (scaleV == null ) { // if scale is null, no need to eval its child at all
706+ s """
707+ boolean ${ev.isNull} = true;
708+ ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
709+ """
710+ } else {
711+ s """
712+ ${ce.code}
713+ boolean ${ev.isNull} = ${ce.isNull};
714+ ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
715+ if (! ${ev.isNull}) {
716+ $evaluationCode
717+ }
646718 """
719+ }
647720 }
721+
722+ override def prettyName : String = " round"
648723}
0 commit comments