@@ -404,11 +404,7 @@ case class Atan2(left: Expression, right: Expression)
404404case class Pow (left : Expression , right : Expression )
405405 extends BinaryMathExpression (math.pow, " POWER" ) {
406406 override def genCode (ctx : CodeGenContext , ev : GeneratedExpressionCode ): String = {
407- defineCodeGen(ctx, ev, (c1, c2) => s " java.lang.Math.pow( $c1, $c2) " ) + s """
408- if (Double.valueOf( ${ev.primitive}).isNaN()) {
409- ${ev.isNull} = true;
410- }
411- """
407+ defineCodeGen(ctx, ev, (c1, c2) => s " java.lang.Math.pow( $c1, $c2) " )
412408 }
413409}
414410
@@ -530,20 +526,20 @@ case class Round(child: Expression, scale: Expression) extends Expression {
530526 this (child, Literal (0 ))
531527 }
532528
533- def children : Seq [Expression ] = Seq (child, scale)
529+ override def children : Seq [Expression ] = Seq (child, scale)
530+
531+ override def nullable : Boolean = true
534532
535- def nullable : Boolean = true
533+ override def foldable : Boolean = child.foldable
536534
537- private lazy val scaleV = scale.asInstanceOf [ Literal ].value
535+ private lazy val scaleV = scale.eval( EmptyRow )
538536 private lazy val _scale = if (scaleV != null ) scaleV.asInstanceOf [Int ] else 0
539537
540- override lazy val dataType : DataType = {
541- child.dataType match {
538+ override lazy val dataType : DataType = child.dataType match {
542539 case StringType | BinaryType => DoubleType
543540 case DecimalType .Fixed (p, s) => DecimalType (p, _scale)
544541 case t => t
545542 }
546- }
547543
548544 override def checkInputDataTypes (): TypeCheckResult = {
549545 child.dataType match {
@@ -557,41 +553,42 @@ case class Round(child: Expression, scale: Expression) extends Expression {
557553 if (value.asInstanceOf [Long ] < Int .MinValue || value.asInstanceOf [Long ] > Int .MaxValue ) {
558554 return TypeCheckFailure (" ROUND scale argument out of allowed range" )
559555 }
560- case Literal (_, _ : IntegralType ) | Literal (_, NullType ) => // satisfy requirement
561556 case _ =>
562- if (! scale.foldable) {
563- return TypeCheckFailure (" Only Integral Literal or Null Literal " +
564- s " are allowed for ROUND scale arguments, got ${child.dataType}" )
557+ if ((scale.dataType.isInstanceOf [IntegralType ] || scale.dataType == NullType ) &&
558+ scale.foldable) {
559+ // TODO: foldable LongType is not checked for out of range
560+ // satisfy requirement
561+ } else {
562+ return TypeCheckFailure (" Only Integral or Null foldable Expression " +
563+ s " is allowed for ROUND scale arguments, got ${child.dataType}" )
565564 }
566565 }
567566 TypeCheckSuccess
568567 }
569568
570- def eval (input : InternalRow ): Any = {
571- val evalE = child.eval(input)
569+ private lazy val rounding : (Any ) => (Any ) = roundGen(child.dataType)
572570
573- if (evalE == null || scaleV == null ) return null
574-
575- child.dataType match {
571+ def roundGen (dt : DataType )(x : Any ): Any = {
572+ dt match {
576573 case _ : DecimalType =>
577- val decimal = evalE .asInstanceOf [Decimal ]
574+ val decimal = x .asInstanceOf [Decimal ]
578575 if (decimal.changePrecision(decimal.precision, _scale)) decimal else null
579576 case ByteType =>
580- round(evalE .asInstanceOf [Byte ], _scale)
577+ round(x .asInstanceOf [Byte ], _scale)
581578 case ShortType =>
582- round(evalE .asInstanceOf [Short ], _scale)
579+ round(x .asInstanceOf [Short ], _scale)
583580 case IntegerType =>
584- round(evalE .asInstanceOf [Int ], _scale)
581+ round(x .asInstanceOf [Int ], _scale)
585582 case LongType =>
586- round(evalE .asInstanceOf [Long ], _scale)
583+ round(x .asInstanceOf [Long ], _scale)
587584 case FloatType =>
588- round(evalE .asInstanceOf [Float ], _scale)
585+ round(x .asInstanceOf [Float ], _scale)
589586 case DoubleType =>
590- round(evalE .asInstanceOf [Double ], _scale)
587+ round(x .asInstanceOf [Double ], _scale)
591588 case StringType =>
592- round(evalE .asInstanceOf [UTF8String ].toString, _scale)
589+ round(x .asInstanceOf [UTF8String ].toString, _scale)
593590 case BinaryType =>
594- round(UTF8String .fromBytes(evalE .asInstanceOf [Array [Byte ]]).toString, _scale)
591+ round(UTF8String .fromBytes(x .asInstanceOf [Array [Byte ]]).toString, _scale)
595592 }
596593 }
597594
@@ -606,35 +603,36 @@ case class Round(child: Expression, scale: Expression) extends Expression {
606603
607604 private def round (input : String , scale : Int ): Any = {
608605 try round(input.toDouble, scale) catch {
609- case _ : NumberFormatException => null
606+ case _ : NumberFormatException => null
610607 }
611608 }
612609
610+ def eval (input : InternalRow ): Any = {
611+ val evalE = child.eval(input)
612+ if (evalE == null || scaleV == null ) return null
613+ rounding(evalE)
614+ }
615+
613616 override def genCode (ctx : CodeGenContext , ev : GeneratedExpressionCode ): String = {
614617 val ce = child.gen(ctx)
615618
616- def integralRound (primitive : String ): String = {
617- s """
618- ${ev.primitive} = new java.math.BigDecimal( ${primitive}).
619- setScale( ${_scale}, java.math.BigDecimal.ROUND_HALF_UP) """
620- }
621-
622- def fractionalRound (primitive : String ): String = {
619+ def round (primitive : String , integral : Boolean ): String = {
620+ val (p1, p2) = if (integral) (" new" , " " ) else (" " , " .valueOf" )
623621 s """
624- ${ev.primitive} = java.math.BigDecimal.valueOf ( ${primitive}).
622+ ${ev.primitive} = $p1 java.math.BigDecimal$p2 ( ${primitive}).
625623 setScale( ${_scale}, java.math.BigDecimal.ROUND_HALF_UP) """
626624 }
627625
628- def check (primitive : String , function : String ): String = {
626+ def fractionalCheck (primitive : String , function : String ): String = {
629627 s """
630628 if (Double.isNaN( ${primitive}) || Double.isInfinite( ${primitive})){
631629 ${ev.primitive} = ${primitive};
632630 } else {
633- ${fractionalRound (primitive)}. ${function};
631+ ${round (primitive, false )}. ${function};
634632 } """
635633 }
636634
637- def convert (primitive : String ): String = {
635+ def stringLikeConvert (primitive : String ): String = {
638636 val dName = ctx.freshName(" converter" )
639637 s """
640638 Double $dName = 0.0;
@@ -643,7 +641,7 @@ case class Round(child: Expression, scale: Expression) extends Expression {
643641 } catch (NumberFormatException e) {
644642 ${ev.isNull} = true;
645643 }
646- ${check (dName, " doubleValue()" )}
644+ ${fractionalCheck (dName, " doubleValue()" )}
647645 """
648646 }
649647
@@ -662,21 +660,21 @@ case class Round(child: Expression, scale: Expression) extends Expression {
662660 case _ : DecimalType =>
663661 decimalRound()
664662 case ByteType =>
665- integralRound (ce.primitive) + " .byteValue();"
663+ round (ce.primitive, true ) + " .byteValue();"
666664 case ShortType =>
667- integralRound (ce.primitive) + " .shortValue();"
665+ round (ce.primitive, true ) + " .shortValue();"
668666 case IntegerType =>
669- integralRound (ce.primitive) + " .intValue();"
667+ round (ce.primitive, true ) + " .intValue();"
670668 case LongType =>
671- integralRound (ce.primitive) + " .longValue();"
669+ round (ce.primitive, true ) + " .longValue();"
672670 case FloatType =>
673- check (ce.primitive, " floatValue()" )
671+ fractionalCheck (ce.primitive, " floatValue()" )
674672 case DoubleType =>
675- check (ce.primitive, " doubleValue()" )
673+ fractionalCheck (ce.primitive, " doubleValue()" )
676674 case StringType =>
677- convert (ce.primitive)
675+ stringLikeConvert (ce.primitive)
678676 case BinaryType =>
679- convert (s " ${ctx.stringType}.fromBytes( ${ce.primitive}) " )
677+ stringLikeConvert (s " ${ctx.stringType}.fromBytes( ${ce.primitive}) " )
680678 }
681679
682680 ce.code + s """
0 commit comments