@@ -572,8 +572,8 @@ case class Round(child: Expression, scale: Expression) extends Expression {
572572
573573 if (evalE == null || scaleV == null ) return null
574574
575- children( 0 ) .dataType match {
576- case decimalType : DecimalType =>
575+ child .dataType match {
576+ case _ : DecimalType =>
577577 val decimal = evalE.asInstanceOf [Decimal ]
578578 if (decimal.changePrecision(decimal.precision, _scale)) decimal else null
579579 case ByteType =>
@@ -595,6 +595,84 @@ case class Round(child: Expression, scale: Expression) extends Expression {
595595 }
596596 }
597597
598+ override def genCode (ctx : CodeGenContext , ev : GeneratedExpressionCode ): String = {
599+ val ce = child.gen(ctx)
600+
601+ def integralRound (primitive : String ): String = {
602+ s """
603+ ${ev.primitive} = new java.math.BigDecimal( ${primitive}).
604+ setScale( ${_scale}, java.math.BigDecimal.ROUND_HALF_UP) """
605+ }
606+
607+ def fractionalRound (primitive : String ): String = {
608+ s """
609+ ${ev.primitive} = java.math.BigDecimal.valueOf( ${primitive}).
610+ setScale( ${_scale}, java.math.BigDecimal.ROUND_HALF_UP) """
611+ }
612+
613+ def check (primitive : String , function : String ): String = {
614+ s """
615+ if (Double.isNaN( ${primitive}) || Double.isInfinite( ${primitive})){
616+ ${ev.primitive} = ${primitive};
617+ } else {
618+ ${fractionalRound(primitive)}. ${function};
619+ } """
620+ }
621+
622+ def convert (primitive : String ): String = {
623+ val dName = ctx.freshName(" converter" )
624+ s """
625+ Double $dName = 0.0;
626+ try {
627+ $dName = Double.valueOf( ${primitive}.toString());
628+ } catch (NumberFormatException e) {
629+ ${ev.isNull} = true;
630+ }
631+ ${check(dName, " doubleValue()" )}
632+ """
633+ }
634+
635+ def decimalRound (): String = {
636+ s """
637+ if ( ${ce.primitive}.changePrecision( ${ce.primitive}.precision(), ${_scale})) {
638+ ${ev.primitive} = ${ce.primitive};
639+ } else {
640+ ${ev.isNull} = true;
641+ }
642+ """
643+ }
644+
645+ val roundCode = child.dataType match {
646+ case NullType => " ;"
647+ case _ : DecimalType =>
648+ decimalRound()
649+ case ByteType =>
650+ integralRound(ce.primitive) + " .byteValue();"
651+ case ShortType =>
652+ integralRound(ce.primitive) + " .shortValue();"
653+ case IntegerType =>
654+ integralRound(ce.primitive) + " .intValue();"
655+ case LongType =>
656+ integralRound(ce.primitive) + " .longValue();"
657+ case FloatType =>
658+ check(ce.primitive, " floatValue()" )
659+ case DoubleType =>
660+ check(ce.primitive, " doubleValue()" )
661+ case StringType =>
662+ convert(ce.primitive)
663+ case BinaryType =>
664+ convert(s " ${ctx.stringType}.fromBytes( ${ce.primitive}) " )
665+ }
666+
667+ ce.code + s """
668+ boolean ${ev.isNull} = ${ce.isNull} || ${scaleV == null };
669+ ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
670+ if (! ${ev.isNull}) {
671+ ${roundCode}
672+ }
673+ """
674+ }
675+
598676 private def round [T ](input : T , scale : Int )(implicit bdc : BigDecimalConverter [T ]): T = {
599677 input match {
600678 case f : Float if (f.isNaN || f.isInfinite) => return input
0 commit comments