@@ -23,7 +23,6 @@ import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
2323import org .apache .spark .sql .catalyst .analysis .TypeCheckResult .{TypeCheckSuccess , TypeCheckFailure }
2424import org .apache .spark .sql .catalyst .expressions .codegen ._
2525import org .apache .spark .sql .catalyst .InternalRow
26- import org .apache .spark .sql .catalyst .util .BigDecimalConverter
2726import org .apache .spark .sql .types ._
2827import org .apache .spark .unsafe .types .UTF8String
2928
@@ -524,144 +523,125 @@ case class Logarithm(left: Expression, right: Expression)
524523 }
525524}
526525
527- case class Round (child : Expression , scale : Expression ) extends Expression with ExpectsInputTypes {
526+ case class Round (child : Expression , scale : Expression )
527+ extends BinaryExpression with ExpectsInputTypes {
528528
529- def this (child : Expression ) = {
530- this (child, Literal (0 ))
531- }
529+ import BigDecimal .RoundingMode .HALF_UP
530+
531+ def this (child : Expression ) = this (child, Literal (0 ))
532+
533+ override def left : Expression = child
534+ override def right : Expression = scale
532535
533536 override def children : Seq [Expression ] = Seq (child, scale)
534537
538+ // round of Decimal would eval to null if it fails to `changePrecision`
535539 override def nullable : Boolean = true
536540
537541 override def foldable : Boolean = child.foldable
538542
539543 override lazy val dataType : DataType = child.dataType match {
540- case DecimalType .Fixed (p, s) => DecimalType (p, _scale)
541- case t => t
542- }
544+ case DecimalType .Fixed (p, s) => DecimalType (p, _scale)
545+ case t => t
546+ }
543547
544- override def inputTypes : Seq [AbstractDataType ] = Seq (
545- // rely on precedence to implicit cast String into Double
546- TypeCollection (DecimalType , DoubleType , FloatType , LongType , IntegerType , ShortType , ByteType ),
547- TypeCollection (LongType , IntegerType , ShortType , ByteType ))
548+ override def inputTypes : Seq [AbstractDataType ] = Seq (NumericType , IntegerType )
548549
549550 override def checkInputDataTypes (): TypeCheckResult = {
550- child.dataType match {
551- case _ : NumericType => // satisfy requirement
552- case dt =>
553- return TypeCheckFailure (s " Only numeric type is allowed for ROUND function, got $dt" )
554- }
555- scale match {
556- case Literal (value, LongType ) =>
557- if (value.asInstanceOf [Long ] < Int .MinValue || value.asInstanceOf [Long ] > Int .MaxValue ) {
558- return TypeCheckFailure (" ROUND scale argument out of allowed range" )
559- }
560- case _ =>
561- if (scale.dataType.isInstanceOf [IntegralType ] && scale.foldable) {
562- // TODO: How to check out of range for foldable LongType Expression
563- // satisfy requirement
551+ super .checkInputDataTypes() match {
552+ case TypeCheckSuccess =>
553+ if (scale.foldable) {
554+ TypeCheckSuccess
564555 } else {
565- return TypeCheckFailure (" Only foldable Integral Expression " +
566- s " is allowed for ROUND scale arguments, got ${child.dataType}" )
556+ TypeCheckFailure (" Only foldable Expression is allowed for scale arguments" )
567557 }
558+ case f => f
568559 }
569- TypeCheckSuccess
570560 }
571561
572562 private lazy val scaleV = scale.eval(EmptyRow )
573563 private lazy val _scale = if (scaleV != null ) scaleV.asInstanceOf [Int ] else 0
574564
575- override def eval (input : InternalRow ): Any = {
576- val evalE = child.eval(input)
577- if (evalE == null || scaleV == null ) return null
578- round(evalE)
579- }
580-
581- private lazy val round : (Any ) => (Any ) = typedRound(child.dataType)
582-
583- // Using dataType info to find an appropriate round method
584- private def typedRound (dt : DataType )(x : Any ): Any = {
585- dt match {
565+ protected override def nullSafeEval (input1 : Any , input2 : Any ): Any = {
566+ child.dataType match {
586567 case _ : DecimalType =>
587- val decimal = x .asInstanceOf [Decimal ]
568+ val decimal = input1 .asInstanceOf [Decimal ]
588569 if (decimal.changePrecision(decimal.precision, _scale)) decimal else null
589570 case ByteType =>
590- numericRound(x .asInstanceOf [Byte ], _scale)
571+ BigDecimal (input1 .asInstanceOf [Byte ]).setScale(_scale, HALF_UP ).toByte
591572 case ShortType =>
592- numericRound(x .asInstanceOf [Short ], _scale)
573+ BigDecimal (input1 .asInstanceOf [Short ]).setScale(_scale, HALF_UP ).toShort
593574 case IntegerType =>
594- numericRound(x .asInstanceOf [Int ], _scale)
575+ BigDecimal (input1 .asInstanceOf [Int ]).setScale(_scale, HALF_UP ).toInt
595576 case LongType =>
596- numericRound(x .asInstanceOf [Long ], _scale)
577+ BigDecimal (input1 .asInstanceOf [Long ]).setScale(_scale, HALF_UP ).toLong
597578 case FloatType =>
598- numericRound(x.asInstanceOf [Float ], _scale)
579+ val f = input1.asInstanceOf [Float ]
580+ if (f.isNaN || f.isInfinite) {
581+ f
582+ } else {
583+ BigDecimal (f).setScale(_scale, HALF_UP ).toFloat
584+ }
599585 case DoubleType =>
600- numericRound(x.asInstanceOf [Double ], _scale)
601- }
602- }
603-
604- private def numericRound [T ](input : T , scale : Int )(implicit bdc : BigDecimalConverter [T ]): T = {
605- input match {
606- case f : Float if (f.isNaN || f.isInfinite) => return input
607- case d : Double if (d.isNaN || d.isInfinite) => return input
608- case _ =>
586+ val d = input1.asInstanceOf [Double ]
587+ if (d.isNaN || d.isInfinite) {
588+ d
589+ } else {
590+ BigDecimal (d).setScale(_scale, HALF_UP ).toDouble
591+ }
609592 }
610- bdc.fromBigDecimal(bdc.toBigDecimal(input).setScale(scale, BigDecimal .RoundingMode .HALF_UP ))
611593 }
612594
613595 override def genCode (ctx : CodeGenContext , ev : GeneratedExpressionCode ): String = {
614596 val ce = child.gen(ctx)
615597
616- def round (primitive : String , integral : Boolean ): String = {
617- val (p1, p2) = if (integral) (" new" , " " ) else (" " , " .valueOf" )
618- s """
619- ${ev.primitive} = $p1 java.math.BigDecimal $p2( ${primitive}).
620- setScale( ${_scale}, java.math.BigDecimal.ROUND_HALF_UP) """
621- }
622-
623- def fractionalCheck (primitive : String , function : String ): String = {
624- s """
625- if (Double.isNaN( ${primitive}) || Double.isInfinite( ${primitive})){
626- ${ev.primitive} = ${primitive};
627- } else {
628- ${round(primitive, false )}. ${function};
629- } """
630- }
631-
632- def decimalRound (): String = {
633- s """
634- if ( ${ce.primitive}.changePrecision( ${ce.primitive}.precision(), ${_scale})) {
635- ${ev.primitive} = ${ce.primitive};
636- } else {
637- ${ev.isNull} = true;
638- }
639- """
640- }
641-
642- val roundCode = child.dataType match {
643- case NullType => " ;"
598+ val evaluationCode = child.dataType match {
644599 case _ : DecimalType =>
645- decimalRound()
600+ s """
601+ if ( ${ce.primitive}.changePrecision( ${ce.primitive}.precision(), ${_scale})) {
602+ ${ev.primitive} = ${ce.primitive};
603+ } else {
604+ ${ev.isNull} = true;
605+ } """
646606 case ByteType =>
647- round(ce.primitive, true ) + " .byteValue();"
607+ s """
608+ ${ev.primitive} = new java.math.BigDecimal( ${ce.primitive}).
609+ setScale( ${_scale}, java.math.BigDecimal.ROUND_HALF_UP).byteValue(); """
648610 case ShortType =>
649- round(ce.primitive, true ) + " .shortValue();"
611+ s """
612+ ${ev.primitive} = new java.math.BigDecimal( ${ce.primitive}).
613+ setScale( ${_scale}, java.math.BigDecimal.ROUND_HALF_UP).shortValue(); """
650614 case IntegerType =>
651- round(ce.primitive, true ) + " .intValue();"
615+ s """
616+ ${ev.primitive} = new java.math.BigDecimal( ${ce.primitive}).
617+ setScale( ${_scale}, java.math.BigDecimal.ROUND_HALF_UP).intValue(); """
652618 case LongType =>
653- round(ce.primitive, true ) + " .longValue();"
619+ s """
620+ ${ev.primitive} = new java.math.BigDecimal( ${ce.primitive}).
621+ setScale( ${_scale}, java.math.BigDecimal.ROUND_HALF_UP).longValue(); """
654622 case FloatType =>
655- fractionalCheck(ce.primitive, " floatValue()" )
623+ s """
624+ if (Float.isNaN( ${ce.primitive}) || Float.isInfinite( ${ce.primitive})){
625+ ${ev.primitive} = ${ce.primitive};
626+ } else {
627+ ${ev.primitive} = java.math.BigDecimal.valueOf( ${ce.primitive}).
628+ setScale( ${_scale}, java.math.BigDecimal.ROUND_HALF_UP).floatValue();
629+ } """
656630 case DoubleType =>
657- fractionalCheck(ce.primitive, " doubleValue()" )
631+ s """
632+ if (Double.isNaN( ${ce.primitive}) || Double.isInfinite( ${ce.primitive})){
633+ ${ev.primitive} = ${ce.primitive};
634+ } else {
635+ ${ev.primitive} = java.math.BigDecimal.valueOf( ${ce.primitive}).
636+ setScale( ${_scale}, java.math.BigDecimal.ROUND_HALF_UP).doubleValue();
637+ } """
658638 }
659639
660640 ce.code + s """
661641 boolean ${ev.isNull} = ${ce.isNull} || ${scaleV == null };
662642 ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
663643 if (! ${ev.isNull}) {
664- ${roundCode }
644+ ${evaluationCode }
665645 }
666646 """
667647 }
0 commit comments