@@ -19,8 +19,10 @@ package org.apache.spark.sql.catalyst.expressions
1919
2020import java .{lang => jl }
2121
22- import org .apache .spark .sql .catalyst .InternalRow
22+ import org .apache .spark .sql .catalyst .analysis .TypeCheckResult
23+ import org .apache .spark .sql .catalyst .analysis .TypeCheckResult .{TypeCheckSuccess , TypeCheckFailure }
2324import org .apache .spark .sql .catalyst .expressions .codegen ._
25+ import org .apache .spark .sql .catalyst .InternalRow
2426import org .apache .spark .sql .types ._
2527import org .apache .spark .unsafe .types .UTF8String
2628
@@ -520,3 +522,202 @@ case class Logarithm(left: Expression, right: Expression)
520522 """
521523 }
522524}
525+
526+ /**
527+ * Round the `child`'s result to `scale` decimal place when `scale` >= 0
528+ * or round at integral part when `scale` < 0.
529+ * For example, round(31.415, 2) would eval to 31.42 and round(31.415, -1) would eval to 30.
530+ *
531+ * Child of IntegralType would eval to itself when `scale` >= 0.
532+ * Child of FractionalType whose value is NaN or Infinite would always eval to itself.
533+ *
534+ * Round's dataType would always equal to `child`'s dataType except for [[DecimalType.Fixed ]],
535+ * which leads to scale update in DecimalType's [[PrecisionInfo ]]
536+ *
537+ * @param child expr to be round, all [[NumericType ]] is allowed as Input
538+ * @param scale new scale to be round to, this should be a constant int at runtime
539+ */
540+ case class Round (child : Expression , scale : Expression )
541+ extends BinaryExpression with ExpectsInputTypes {
542+
543+ import BigDecimal .RoundingMode .HALF_UP
544+
545+ def this (child : Expression ) = this (child, Literal (0 ))
546+
547+ override def left : Expression = child
548+ override def right : Expression = scale
549+
550+ // round of Decimal would eval to null if it fails to `changePrecision`
551+ override def nullable : Boolean = true
552+
553+ override def foldable : Boolean = child.foldable
554+
555+ override lazy val dataType : DataType = child.dataType match {
556+ // if the new scale is bigger which means we are scaling up,
557+ // keep the original scale as `Decimal` does
558+ case DecimalType .Fixed (p, s) => DecimalType (p, if (_scale > s) s else _scale)
559+ case t => t
560+ }
561+
562+ override def inputTypes : Seq [AbstractDataType ] = Seq (NumericType , IntegerType )
563+
564+ override def checkInputDataTypes (): TypeCheckResult = {
565+ super .checkInputDataTypes() match {
566+ case TypeCheckSuccess =>
567+ if (scale.foldable) {
568+ TypeCheckSuccess
569+ } else {
570+ TypeCheckFailure (" Only foldable Expression is allowed for scale arguments" )
571+ }
572+ case f => f
573+ }
574+ }
575+
576+ // Avoid repeated evaluation since `scale` is a constant int,
577+ // avoid unnecessary `child` evaluation in both codegen and non-codegen eval
578+ // by checking if scaleV == null as well.
579+ private lazy val scaleV : Any = scale.eval(EmptyRow )
580+ private lazy val _scale : Int = scaleV.asInstanceOf [Int ]
581+
582+ override def eval (input : InternalRow ): Any = {
583+ if (scaleV == null ) { // if scale is null, no need to eval its child at all
584+ null
585+ } else {
586+ val evalE = child.eval(input)
587+ if (evalE == null ) {
588+ null
589+ } else {
590+ nullSafeEval(evalE)
591+ }
592+ }
593+ }
594+
595+ // not overriding since _scale is a constant int at runtime
596+ def nullSafeEval (input1 : Any ): Any = {
597+ child.dataType match {
598+ case _ : DecimalType =>
599+ val decimal = input1.asInstanceOf [Decimal ]
600+ if (decimal.changePrecision(decimal.precision, _scale)) decimal else null
601+ case ByteType =>
602+ BigDecimal (input1.asInstanceOf [Byte ]).setScale(_scale, HALF_UP ).toByte
603+ case ShortType =>
604+ BigDecimal (input1.asInstanceOf [Short ]).setScale(_scale, HALF_UP ).toShort
605+ case IntegerType =>
606+ BigDecimal (input1.asInstanceOf [Int ]).setScale(_scale, HALF_UP ).toInt
607+ case LongType =>
608+ BigDecimal (input1.asInstanceOf [Long ]).setScale(_scale, HALF_UP ).toLong
609+ case FloatType =>
610+ val f = input1.asInstanceOf [Float ]
611+ if (f.isNaN || f.isInfinite) {
612+ f
613+ } else {
614+ BigDecimal (f).setScale(_scale, HALF_UP ).toFloat
615+ }
616+ case DoubleType =>
617+ val d = input1.asInstanceOf [Double ]
618+ if (d.isNaN || d.isInfinite) {
619+ d
620+ } else {
621+ BigDecimal (d).setScale(_scale, HALF_UP ).toDouble
622+ }
623+ }
624+ }
625+
626+ override def genCode (ctx : CodeGenContext , ev : GeneratedExpressionCode ): String = {
627+ val ce = child.gen(ctx)
628+
629+ val evaluationCode = child.dataType match {
630+ case _ : DecimalType =>
631+ s """
632+ if ( ${ce.primitive}.changePrecision( ${ce.primitive}.precision(), ${_scale})) {
633+ ${ev.primitive} = ${ce.primitive};
634+ } else {
635+ ${ev.isNull} = true;
636+ } """
637+ case ByteType =>
638+ if (_scale < 0 ) {
639+ s """
640+ ${ev.primitive} = new java.math.BigDecimal( ${ce.primitive}).
641+ setScale( ${_scale}, java.math.BigDecimal.ROUND_HALF_UP).byteValue(); """
642+ } else {
643+ s " ${ev.primitive} = ${ce.primitive}; "
644+ }
645+ case ShortType =>
646+ if (_scale < 0 ) {
647+ s """
648+ ${ev.primitive} = new java.math.BigDecimal( ${ce.primitive}).
649+ setScale( ${_scale}, java.math.BigDecimal.ROUND_HALF_UP).shortValue(); """
650+ } else {
651+ s " ${ev.primitive} = ${ce.primitive}; "
652+ }
653+ case IntegerType =>
654+ if (_scale < 0 ) {
655+ s """
656+ ${ev.primitive} = new java.math.BigDecimal( ${ce.primitive}).
657+ setScale( ${_scale}, java.math.BigDecimal.ROUND_HALF_UP).intValue(); """
658+ } else {
659+ s " ${ev.primitive} = ${ce.primitive}; "
660+ }
661+ case LongType =>
662+ if (_scale < 0 ) {
663+ s """
664+ ${ev.primitive} = new java.math.BigDecimal( ${ce.primitive}).
665+ setScale( ${_scale}, java.math.BigDecimal.ROUND_HALF_UP).longValue(); """
666+ } else {
667+ s " ${ev.primitive} = ${ce.primitive}; "
668+ }
669+ case FloatType => // if child eval to NaN or Infinity, just return it.
670+ if (_scale == 0 ) {
671+ s """
672+ if (Float.isNaN( ${ce.primitive}) || Float.isInfinite( ${ce.primitive})){
673+ ${ev.primitive} = ${ce.primitive};
674+ } else {
675+ ${ev.primitive} = Math.round( ${ce.primitive});
676+ } """
677+ } else {
678+ s """
679+ if (Float.isNaN( ${ce.primitive}) || Float.isInfinite( ${ce.primitive})){
680+ ${ev.primitive} = ${ce.primitive};
681+ } else {
682+ ${ev.primitive} = java.math.BigDecimal.valueOf( ${ce.primitive}).
683+ setScale( ${_scale}, java.math.BigDecimal.ROUND_HALF_UP).floatValue();
684+ } """
685+ }
686+ case DoubleType => // if child eval to NaN or Infinity, just return it.
687+ if (_scale == 0 ) {
688+ s """
689+ if (Double.isNaN( ${ce.primitive}) || Double.isInfinite( ${ce.primitive})){
690+ ${ev.primitive} = ${ce.primitive};
691+ } else {
692+ ${ev.primitive} = Math.round( ${ce.primitive});
693+ } """
694+ } else {
695+ s """
696+ if (Double.isNaN( ${ce.primitive}) || Double.isInfinite( ${ce.primitive})){
697+ ${ev.primitive} = ${ce.primitive};
698+ } else {
699+ ${ev.primitive} = java.math.BigDecimal.valueOf( ${ce.primitive}).
700+ setScale( ${_scale}, java.math.BigDecimal.ROUND_HALF_UP).doubleValue();
701+ } """
702+ }
703+ }
704+
705+ if (scaleV == null ) { // if scale is null, no need to eval its child at all
706+ s """
707+ boolean ${ev.isNull} = true;
708+ ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
709+ """
710+ } else {
711+ s """
712+ ${ce.code}
713+ boolean ${ev.isNull} = ${ce.isNull};
714+ ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
715+ if (! ${ev.isNull}) {
716+ $evaluationCode
717+ }
718+ """
719+ }
720+ }
721+
722+ override def prettyName : String = " round"
723+ }
0 commit comments