@@ -29,12 +29,13 @@ abstract class UnaryArithmetic extends UnaryExpression {
2929 override def dataType : DataType = child.dataType
3030}
3131
32- case class UnaryMinus (child : Expression ) extends UnaryArithmetic with ExpectsInputTypes {
33-
34- override def inputTypes : Seq [AbstractDataType ] = Seq (NumericType )
32+ case class UnaryMinus (child : Expression ) extends UnaryArithmetic {
3533
3634 override def toString : String = s " - $child"
3735
36+ override def checkInputDataTypes (): TypeCheckResult =
37+ TypeUtils .checkForNumericExpr(child.dataType, " operator -" )
38+
3839 private lazy val numeric = TypeUtils .getNumeric(dataType)
3940
4041 override def genCode (ctx : CodeGenContext , ev : GeneratedExpressionCode ): String = dataType match {
@@ -48,6 +49,9 @@ case class UnaryMinus(child: Expression) extends UnaryArithmetic with ExpectsInp
4849case class UnaryPositive (child : Expression ) extends UnaryArithmetic {
4950 override def prettyName : String = " positive"
5051
52+ override def checkInputDataTypes (): TypeCheckResult =
53+ TypeUtils .checkForNumericExpr(child.dataType, " operator -" )
54+
5155 override def genCode (ctx : CodeGenContext , ev : GeneratedExpressionCode ): String =
5256 defineCodeGen(ctx, ev, c => c)
5357
@@ -57,9 +61,9 @@ case class UnaryPositive(child: Expression) extends UnaryArithmetic {
5761/**
5862 * A function that get the absolute value of the numeric value.
5963 */
60- case class Abs (child : Expression ) extends UnaryArithmetic with ExpectsInputTypes {
61-
62- override def inputTypes : Seq [ AbstractDataType ] = Seq ( NumericType )
64+ case class Abs (child : Expression ) extends UnaryArithmetic {
65+ override def checkInputDataTypes () : TypeCheckResult =
66+ TypeUtils .checkForNumericExpr(child.dataType, " function abs " )
6367
6468 private lazy val numeric = TypeUtils .getNumeric(dataType)
6569
@@ -91,14 +95,13 @@ private[sql] object BinaryArithmetic {
9195 def unapply (e : BinaryArithmetic ): Option [(Expression , Expression )] = Some ((e.left, e.right))
9296}
9397
94- case class Add (left : Expression , right : Expression )
95- extends BinaryArithmetic with ExpectsInputTypes {
98+ case class Add (left : Expression , right : Expression ) extends BinaryArithmetic {
99+
100+ override def inputType : AbstractDataType = NumericType
96101
97102 override def symbol : String = " +"
98103 override def decimalMethod : String = " $plus"
99104
100- override def inputTypes : Seq [AbstractDataType ] = Seq (NumericType , NumericType )
101-
102105 override lazy val resolved =
103106 childrenResolved && checkInputDataTypes().isSuccess && ! DecimalType .isFixed(dataType)
104107
@@ -107,14 +110,13 @@ case class Add(left: Expression, right: Expression)
107110 protected override def nullSafeEval (input1 : Any , input2 : Any ): Any = numeric.plus(input1, input2)
108111}
109112
110- case class Subtract (left : Expression , right : Expression )
111- extends BinaryArithmetic with ExpectsInputTypes {
113+ case class Subtract (left : Expression , right : Expression ) extends BinaryArithmetic {
114+
115+ override def inputType : AbstractDataType = NumericType
112116
113117 override def symbol : String = " -"
114118 override def decimalMethod : String = " $minus"
115119
116- override def inputTypes : Seq [AbstractDataType ] = Seq (NumericType , NumericType )
117-
118120 override lazy val resolved =
119121 childrenResolved && checkInputDataTypes().isSuccess && ! DecimalType .isFixed(dataType)
120122
@@ -123,14 +125,13 @@ case class Subtract(left: Expression, right: Expression)
123125 protected override def nullSafeEval (input1 : Any , input2 : Any ): Any = numeric.minus(input1, input2)
124126}
125127
126- case class Multiply (left : Expression , right : Expression )
127- extends BinaryArithmetic with ExpectsInputTypes {
128+ case class Multiply (left : Expression , right : Expression ) extends BinaryArithmetic {
129+
130+ override def inputType : AbstractDataType = NumericType
128131
129132 override def symbol : String = " *"
130133 override def decimalMethod : String = " $times"
131134
132- override def inputTypes : Seq [AbstractDataType ] = Seq (NumericType , NumericType )
133-
134135 override lazy val resolved =
135136 childrenResolved && checkInputDataTypes().isSuccess && ! DecimalType .isFixed(dataType)
136137
@@ -139,15 +140,14 @@ case class Multiply(left: Expression, right: Expression)
139140 protected override def nullSafeEval (input1 : Any , input2 : Any ): Any = numeric.times(input1, input2)
140141}
141142
142- case class Divide (left : Expression , right : Expression )
143- extends BinaryArithmetic with ExpectsInputTypes {
143+ case class Divide (left : Expression , right : Expression ) extends BinaryArithmetic {
144+
145+ override def inputType : AbstractDataType = NumericType
144146
145147 override def symbol : String = " /"
146148 override def decimalMethod : String = " $div"
147149 override def nullable : Boolean = true
148150
149- override def inputTypes : Seq [AbstractDataType ] = Seq (NumericType , NumericType )
150-
151151 override lazy val resolved =
152152 childrenResolved && checkInputDataTypes().isSuccess && ! DecimalType .isFixed(dataType)
153153
@@ -205,15 +205,14 @@ case class Divide(left: Expression, right: Expression)
205205 }
206206}
207207
208- case class Remainder (left : Expression , right : Expression )
209- extends BinaryArithmetic with ExpectsInputTypes {
208+ case class Remainder (left : Expression , right : Expression ) extends BinaryArithmetic {
209+
210+ override def inputType : AbstractDataType = NumericType
210211
211212 override def symbol : String = " %"
212213 override def decimalMethod : String = " remainder"
213214 override def nullable : Boolean = true
214215
215- override def inputTypes : Seq [AbstractDataType ] = Seq (NumericType , NumericType )
216-
217216 override lazy val resolved =
218217 childrenResolved && checkInputDataTypes().isSuccess && ! DecimalType .isFixed(dataType)
219218
@@ -272,6 +271,10 @@ case class Remainder(left: Expression, right: Expression)
272271}
273272
274273case class MaxOf (left : Expression , right : Expression ) extends BinaryArithmetic {
274+ // TODO: Remove MaxOf and MinOf, and replace its usage with Greatest and Least.
275+
276+ override def inputType : AbstractDataType = TypeCollection .Ordered
277+
275278 override def nullable : Boolean = left.nullable && right.nullable
276279
277280 protected def checkTypesInternal (t : DataType ) =
@@ -326,6 +329,10 @@ case class MaxOf(left: Expression, right: Expression) extends BinaryArithmetic {
326329}
327330
328331case class MinOf (left : Expression , right : Expression ) extends BinaryArithmetic {
332+ // TODO: Remove MaxOf and MinOf, and replace its usage with Greatest and Least.
333+
334+ override def inputType : AbstractDataType = TypeCollection .Ordered
335+
329336 override def nullable : Boolean = left.nullable && right.nullable
330337
331338 protected def checkTypesInternal (t : DataType ) =
0 commit comments