@@ -26,7 +26,7 @@ import org.apache.spark.unsafe.types.Interval
2626
2727case class UnaryMinus (child : Expression ) extends UnaryExpression with ExpectsInputTypes {
2828
29- override def inputTypes : Seq [AbstractDataType ] = Seq (NumericType )
29+ override def inputTypes : Seq [AbstractDataType ] = Seq (TypeCollection . NumericAndInterval )
3030
3131 override def dataType : DataType = child.dataType
3232
@@ -37,15 +37,22 @@ case class UnaryMinus(child: Expression) extends UnaryExpression with ExpectsInp
3737 override def genCode (ctx : CodeGenContext , ev : GeneratedExpressionCode ): String = dataType match {
3838 case dt : DecimalType => defineCodeGen(ctx, ev, c => s " $c.unary_ $$ minus() " )
3939 case dt : NumericType => defineCodeGen(ctx, ev, c => s " ( ${ctx.javaType(dt)})(-( $c)) " )
40+ case dt : IntervalType => defineCodeGen(ctx, ev, c => s " $c.negate() " )
4041 }
4142
42- protected override def nullSafeEval (input : Any ): Any = numeric.negate(input)
43+ protected override def nullSafeEval (input : Any ): Any = {
44+ if (dataType.isInstanceOf [IntervalType ]) {
45+ input.asInstanceOf [Interval ].negate()
46+ } else {
47+ numeric.negate(input)
48+ }
49+ }
4350}
4451
4552case class UnaryPositive (child : Expression ) extends UnaryExpression with ExpectsInputTypes {
4653 override def prettyName : String = " positive"
4754
48- override def inputTypes : Seq [AbstractDataType ] = Seq (NumericType )
55+ override def inputTypes : Seq [AbstractDataType ] = Seq (TypeCollection . NumericAndInterval )
4956
5057 override def dataType : DataType = child.dataType
5158
@@ -85,8 +92,6 @@ abstract class BinaryArithmetic extends BinaryOperator {
8592 case ByteType | ShortType =>
8693 defineCodeGen(ctx, ev,
8794 (eval1, eval2) => s " ( ${ctx.javaType(dataType)})( $eval1 $symbol $eval2) " )
88- case IntervalType =>
89- defineCodeGen(ctx, ev, (eval1, eval2) => s """ $eval1.doOp( $eval2, " $symbol") """ )
9095 case _ =>
9196 defineCodeGen(ctx, ev, (eval1, eval2) => s " $eval1 $symbol $eval2" )
9297 }
@@ -98,8 +103,7 @@ private[sql] object BinaryArithmetic {
98103
99104case class Add (left : Expression , right : Expression ) extends BinaryArithmetic {
100105
101- override def inputType : AbstractDataType = NumericType
102- override def inputTypes : Seq [AbstractDataType ] = Seq (NumericType , IntervalType )
106+ override def inputType : AbstractDataType = TypeCollection .NumericAndInterval
103107
104108 override def symbol : String = " +"
105109 override def decimalMethod : String = " $plus"
@@ -116,12 +120,23 @@ case class Add(left: Expression, right: Expression) extends BinaryArithmetic {
116120 numeric.plus(input1, input2)
117121 }
118122 }
123+
124+ override def genCode (ctx : CodeGenContext , ev : GeneratedExpressionCode ): String = dataType match {
125+ case dt : DecimalType =>
126+ defineCodeGen(ctx, ev, (eval1, eval2) => s " $eval1. $decimalMethod( $eval2) " )
127+ case ByteType | ShortType =>
128+ defineCodeGen(ctx, ev,
129+ (eval1, eval2) => s " ( ${ctx.javaType(dataType)})( $eval1 $symbol $eval2) " )
130+ case IntervalType =>
131+ defineCodeGen(ctx, ev, (eval1, eval2) => s " $eval1.add( $eval2) " )
132+ case _ =>
133+ defineCodeGen(ctx, ev, (eval1, eval2) => s " $eval1 $symbol $eval2" )
134+ }
119135}
120136
121137case class Subtract (left : Expression , right : Expression ) extends BinaryArithmetic {
122138
123- override def inputType : AbstractDataType = NumericType
124- override def inputTypes : Seq [AbstractDataType ] = Seq (NumericType , IntervalType )
139+ override def inputType : AbstractDataType = TypeCollection .NumericAndInterval
125140
126141 override def symbol : String = " -"
127142 override def decimalMethod : String = " $minus"
@@ -138,6 +153,18 @@ case class Subtract(left: Expression, right: Expression) extends BinaryArithmeti
138153 numeric.minus(input1, input2)
139154 }
140155 }
156+
157+ override def genCode (ctx : CodeGenContext , ev : GeneratedExpressionCode ): String = dataType match {
158+ case dt : DecimalType =>
159+ defineCodeGen(ctx, ev, (eval1, eval2) => s " $eval1. $decimalMethod( $eval2) " )
160+ case ByteType | ShortType =>
161+ defineCodeGen(ctx, ev,
162+ (eval1, eval2) => s " ( ${ctx.javaType(dataType)})( $eval1 $symbol $eval2) " )
163+ case IntervalType =>
164+ defineCodeGen(ctx, ev, (eval1, eval2) => s " $eval1.subtract( $eval2) " )
165+ case _ =>
166+ defineCodeGen(ctx, ev, (eval1, eval2) => s " $eval1 $symbol $eval2" )
167+ }
141168}
142169
143170case class Multiply (left : Expression , right : Expression ) extends BinaryArithmetic {
0 commit comments