1818package org .apache .spark .sql .catalyst .expressions
1919
2020import org .apache .spark .sql .catalyst .InternalRow
21- import org .apache .spark .sql .catalyst .analysis .TypeCheckResult
2221import org .apache .spark .sql .catalyst .expressions .codegen .{CodeGenContext , GeneratedExpressionCode }
2322import org .apache .spark .sql .catalyst .util .TypeUtils
2423import org .apache .spark .sql .types ._
2524
26- abstract class UnaryArithmetic extends UnaryExpression {
27- self : Product =>
25+
26+ case class UnaryMinus (child : Expression ) extends UnaryExpression with ExpectsInputTypes {
27+
28+ override def inputTypes : Seq [AbstractDataType ] = Seq (NumericType )
2829
2930 override def dataType : DataType = child.dataType
30- }
3131
32- case class UnaryMinus (child : Expression ) extends UnaryArithmetic {
3332 override def toString : String = s " - $child"
3433
35- override def checkInputDataTypes (): TypeCheckResult =
36- TypeUtils .checkForNumericExpr(child.dataType, " operator -" )
37-
3834 private lazy val numeric = TypeUtils .getNumeric(dataType)
3935
4036 override def genCode (ctx : CodeGenContext , ev : GeneratedExpressionCode ): String = dataType match {
@@ -45,9 +41,13 @@ case class UnaryMinus(child: Expression) extends UnaryArithmetic {
4541 protected override def nullSafeEval (input : Any ): Any = numeric.negate(input)
4642}
4743
48- case class UnaryPositive (child : Expression ) extends UnaryArithmetic {
44+ case class UnaryPositive (child : Expression ) extends UnaryExpression with ExpectsInputTypes {
4945 override def prettyName : String = " positive"
5046
47+ override def inputTypes : Seq [AbstractDataType ] = Seq (NumericType )
48+
49+ override def dataType : DataType = child.dataType
50+
5151 override def genCode (ctx : CodeGenContext , ev : GeneratedExpressionCode ): String =
5252 defineCodeGen(ctx, ev, c => c)
5353
@@ -57,9 +57,11 @@ case class UnaryPositive(child: Expression) extends UnaryArithmetic {
5757/**
5858 * A function that get the absolute value of the numeric value.
5959 */
60- case class Abs (child : Expression ) extends UnaryArithmetic {
61- override def checkInputDataTypes (): TypeCheckResult =
62- TypeUtils .checkForNumericExpr(child.dataType, " function abs" )
60+ case class Abs (child : Expression ) extends UnaryExpression with ExpectsInputTypes {
61+
62+ override def inputTypes : Seq [AbstractDataType ] = Seq (NumericType )
63+
64+ override def dataType : DataType = child.dataType
6365
6466 private lazy val numeric = TypeUtils .getNumeric(dataType)
6567
@@ -71,18 +73,6 @@ abstract class BinaryArithmetic extends BinaryOperator {
7173
7274 override def dataType : DataType = left.dataType
7375
74- override def checkInputDataTypes (): TypeCheckResult = {
75- if (left.dataType != right.dataType) {
76- TypeCheckResult .TypeCheckFailure (
77- s " differing types in ${this .getClass.getSimpleName} " +
78- s " ( ${left.dataType} and ${right.dataType}). " )
79- } else {
80- checkTypesInternal(dataType)
81- }
82- }
83-
84- protected def checkTypesInternal (t : DataType ): TypeCheckResult
85-
8676 /** Name of the function for this expression on a [[Decimal ]] type. */
8777 def decimalMethod : String =
8878 sys.error(" BinaryArithmetics must override either decimalMethod or genCode" )
@@ -104,62 +94,61 @@ private[sql] object BinaryArithmetic {
10494}
10595
10696case class Add (left : Expression , right : Expression ) extends BinaryArithmetic {
97+
98+ override def inputType : AbstractDataType = NumericType
99+
107100 override def symbol : String = " +"
108101 override def decimalMethod : String = " $plus"
109102
110103 override lazy val resolved =
111104 childrenResolved && checkInputDataTypes().isSuccess && ! DecimalType .isFixed(dataType)
112105
113- protected def checkTypesInternal (t : DataType ) =
114- TypeUtils .checkForNumericExpr(t, " operator " + symbol)
115-
116106 private lazy val numeric = TypeUtils .getNumeric(dataType)
117107
118108 protected override def nullSafeEval (input1 : Any , input2 : Any ): Any = numeric.plus(input1, input2)
119109}
120110
121111case class Subtract (left : Expression , right : Expression ) extends BinaryArithmetic {
112+
113+ override def inputType : AbstractDataType = NumericType
114+
122115 override def symbol : String = " -"
123116 override def decimalMethod : String = " $minus"
124117
125118 override lazy val resolved =
126119 childrenResolved && checkInputDataTypes().isSuccess && ! DecimalType .isFixed(dataType)
127120
128- protected def checkTypesInternal (t : DataType ) =
129- TypeUtils .checkForNumericExpr(t, " operator " + symbol)
130-
131121 private lazy val numeric = TypeUtils .getNumeric(dataType)
132122
133123 protected override def nullSafeEval (input1 : Any , input2 : Any ): Any = numeric.minus(input1, input2)
134124}
135125
136126case class Multiply (left : Expression , right : Expression ) extends BinaryArithmetic {
127+
128+ override def inputType : AbstractDataType = NumericType
129+
137130 override def symbol : String = " *"
138131 override def decimalMethod : String = " $times"
139132
140133 override lazy val resolved =
141134 childrenResolved && checkInputDataTypes().isSuccess && ! DecimalType .isFixed(dataType)
142135
143- protected def checkTypesInternal (t : DataType ) =
144- TypeUtils .checkForNumericExpr(t, " operator " + symbol)
145-
146136 private lazy val numeric = TypeUtils .getNumeric(dataType)
147137
148138 protected override def nullSafeEval (input1 : Any , input2 : Any ): Any = numeric.times(input1, input2)
149139}
150140
151141case class Divide (left : Expression , right : Expression ) extends BinaryArithmetic {
142+
143+ override def inputType : AbstractDataType = NumericType
144+
152145 override def symbol : String = " /"
153146 override def decimalMethod : String = " $div"
154-
155147 override def nullable : Boolean = true
156148
157149 override lazy val resolved =
158150 childrenResolved && checkInputDataTypes().isSuccess && ! DecimalType .isFixed(dataType)
159151
160- protected def checkTypesInternal (t : DataType ) =
161- TypeUtils .checkForNumericExpr(t, " operator " + symbol)
162-
163152 private lazy val div : (Any , Any ) => Any = dataType match {
164153 case ft : FractionalType => ft.fractional.asInstanceOf [Fractional [Any ]].div
165154 case it : IntegralType => it.integral.asInstanceOf [Integral [Any ]].quot
@@ -215,17 +204,16 @@ case class Divide(left: Expression, right: Expression) extends BinaryArithmetic
215204}
216205
217206case class Remainder (left : Expression , right : Expression ) extends BinaryArithmetic {
207+
208+ override def inputType : AbstractDataType = NumericType
209+
218210 override def symbol : String = " %"
219211 override def decimalMethod : String = " remainder"
220-
221212 override def nullable : Boolean = true
222213
223214 override lazy val resolved =
224215 childrenResolved && checkInputDataTypes().isSuccess && ! DecimalType .isFixed(dataType)
225216
226- protected def checkTypesInternal (t : DataType ) =
227- TypeUtils .checkForNumericExpr(t, " operator " + symbol)
228-
229217 private lazy val integral = dataType match {
230218 case i : IntegralType => i.integral.asInstanceOf [Integral [Any ]]
231219 case i : FractionalType => i.asIntegral.asInstanceOf [Integral [Any ]]
@@ -281,10 +269,11 @@ case class Remainder(left: Expression, right: Expression) extends BinaryArithmet
281269}
282270
283271case class MaxOf (left : Expression , right : Expression ) extends BinaryArithmetic {
284- override def nullable : Boolean = left.nullable && right.nullable
272+ // TODO: Remove MaxOf and MinOf, and replace its usage with Greatest and Least.
285273
286- protected def checkTypesInternal (t : DataType ) =
287- TypeUtils .checkForOrderingExpr(t, " function maxOf" )
274+ override def inputType : AbstractDataType = TypeCollection .Ordered
275+
276+ override def nullable : Boolean = left.nullable && right.nullable
288277
289278 private lazy val ordering = TypeUtils .getOrdering(dataType)
290279
@@ -335,10 +324,11 @@ case class MaxOf(left: Expression, right: Expression) extends BinaryArithmetic {
335324}
336325
337326case class MinOf (left : Expression , right : Expression ) extends BinaryArithmetic {
338- override def nullable : Boolean = left.nullable && right.nullable
327+ // TODO: Remove MaxOf and MinOf, and replace its usage with Greatest and Least.
339328
340- protected def checkTypesInternal (t : DataType ) =
341- TypeUtils .checkForOrderingExpr(t, " function minOf" )
329+ override def inputType : AbstractDataType = TypeCollection .Ordered
330+
331+ override def nullable : Boolean = left.nullable && right.nullable
342332
343333 private lazy val ordering = TypeUtils .getOrdering(dataType)
344334
0 commit comments