@@ -29,11 +29,11 @@ abstract class UnaryArithmetic extends UnaryExpression {
2929 override def dataType : DataType = child.dataType
3030}
3131
32- case class UnaryMinus (child : Expression ) extends UnaryArithmetic {
33- override def toString : String = s " - $child"
32+ case class UnaryMinus (child : Expression ) extends UnaryArithmetic with ExpectsInputTypes {
33+
34+ override def inputTypes : Seq [AbstractDataType ] = Seq (NumericType )
3435
35- override def checkInputDataTypes (): TypeCheckResult =
36- TypeUtils .checkForNumericExpr(child.dataType, " operator -" )
36+ override def toString : String = s " - $child"
3737
3838 private lazy val numeric = TypeUtils .getNumeric(dataType)
3939
@@ -57,9 +57,9 @@ case class UnaryPositive(child: Expression) extends UnaryArithmetic {
5757/**
5858 * A function that get the absolute value of the numeric value.
5959 */
60- case class Abs (child : Expression ) extends UnaryArithmetic {
61- override def checkInputDataTypes () : TypeCheckResult =
62- TypeUtils .checkForNumericExpr(child.dataType, " function abs " )
60+ case class Abs (child : Expression ) extends UnaryArithmetic with ExpectsInputTypes {
61+
62+ override def inputTypes : Seq [ AbstractDataType ] = Seq ( NumericType )
6363
6464 private lazy val numeric = TypeUtils .getNumeric(dataType)
6565
@@ -71,18 +71,6 @@ abstract class BinaryArithmetic extends BinaryOperator {
7171
7272 override def dataType : DataType = left.dataType
7373
74- override def checkInputDataTypes (): TypeCheckResult = {
75- if (left.dataType != right.dataType) {
76- TypeCheckResult .TypeCheckFailure (
77- s " differing types in ${this .getClass.getSimpleName} " +
78- s " ( ${left.dataType} and ${right.dataType}). " )
79- } else {
80- checkTypesInternal(dataType)
81- }
82- }
83-
84- protected def checkTypesInternal (t : DataType ): TypeCheckResult
85-
8674 /** Name of the function for this expression on a [[Decimal ]] type. */
8775 def decimalMethod : String =
8876 sys.error(" BinaryArithmetics must override either decimalMethod or genCode" )
@@ -103,63 +91,66 @@ private[sql] object BinaryArithmetic {
10391 def unapply (e : BinaryArithmetic ): Option [(Expression , Expression )] = Some ((e.left, e.right))
10492}
10593
106- case class Add (left : Expression , right : Expression ) extends BinaryArithmetic {
94+ case class Add (left : Expression , right : Expression )
95+ extends BinaryArithmetic with ExpectsInputTypes {
96+
10797 override def symbol : String = " +"
10898 override def decimalMethod : String = " $plus"
10999
100+ override def inputTypes : Seq [AbstractDataType ] = Seq (NumericType , NumericType )
101+
110102 override lazy val resolved =
111103 childrenResolved && checkInputDataTypes().isSuccess && ! DecimalType .isFixed(dataType)
112104
113- protected def checkTypesInternal (t : DataType ) =
114- TypeUtils .checkForNumericExpr(t, " operator " + symbol)
115-
116105 private lazy val numeric = TypeUtils .getNumeric(dataType)
117106
118107 protected override def nullSafeEval (input1 : Any , input2 : Any ): Any = numeric.plus(input1, input2)
119108}
120109
121- case class Subtract (left : Expression , right : Expression ) extends BinaryArithmetic {
110+ case class Subtract (left : Expression , right : Expression )
111+ extends BinaryArithmetic with ExpectsInputTypes {
112+
122113 override def symbol : String = " -"
123114 override def decimalMethod : String = " $minus"
124115
116+ override def inputTypes : Seq [AbstractDataType ] = Seq (NumericType , NumericType )
117+
125118 override lazy val resolved =
126119 childrenResolved && checkInputDataTypes().isSuccess && ! DecimalType .isFixed(dataType)
127120
128- protected def checkTypesInternal (t : DataType ) =
129- TypeUtils .checkForNumericExpr(t, " operator " + symbol)
130-
131121 private lazy val numeric = TypeUtils .getNumeric(dataType)
132122
133123 protected override def nullSafeEval (input1 : Any , input2 : Any ): Any = numeric.minus(input1, input2)
134124}
135125
136- case class Multiply (left : Expression , right : Expression ) extends BinaryArithmetic {
126+ case class Multiply (left : Expression , right : Expression )
127+ extends BinaryArithmetic with ExpectsInputTypes {
128+
137129 override def symbol : String = " *"
138130 override def decimalMethod : String = " $times"
139131
132+ override def inputTypes : Seq [AbstractDataType ] = Seq (NumericType , NumericType )
133+
140134 override lazy val resolved =
141135 childrenResolved && checkInputDataTypes().isSuccess && ! DecimalType .isFixed(dataType)
142136
143- protected def checkTypesInternal (t : DataType ) =
144- TypeUtils .checkForNumericExpr(t, " operator " + symbol)
145-
146137 private lazy val numeric = TypeUtils .getNumeric(dataType)
147138
148139 protected override def nullSafeEval (input1 : Any , input2 : Any ): Any = numeric.times(input1, input2)
149140}
150141
151- case class Divide (left : Expression , right : Expression ) extends BinaryArithmetic {
142+ case class Divide (left : Expression , right : Expression )
143+ extends BinaryArithmetic with ExpectsInputTypes {
144+
152145 override def symbol : String = " /"
153146 override def decimalMethod : String = " $div"
154-
155147 override def nullable : Boolean = true
156148
149+ override def inputTypes : Seq [AbstractDataType ] = Seq (NumericType , NumericType )
150+
157151 override lazy val resolved =
158152 childrenResolved && checkInputDataTypes().isSuccess && ! DecimalType .isFixed(dataType)
159153
160- protected def checkTypesInternal (t : DataType ) =
161- TypeUtils .checkForNumericExpr(t, " operator " + symbol)
162-
163154 private lazy val div : (Any , Any ) => Any = dataType match {
164155 case ft : FractionalType => ft.fractional.asInstanceOf [Fractional [Any ]].div
165156 case it : IntegralType => it.integral.asInstanceOf [Integral [Any ]].quot
@@ -214,18 +205,18 @@ case class Divide(left: Expression, right: Expression) extends BinaryArithmetic
214205 }
215206}
216207
217- case class Remainder (left : Expression , right : Expression ) extends BinaryArithmetic {
208+ case class Remainder (left : Expression , right : Expression )
209+ extends BinaryArithmetic with ExpectsInputTypes {
210+
218211 override def symbol : String = " %"
219212 override def decimalMethod : String = " remainder"
220-
221213 override def nullable : Boolean = true
222214
215+ override def inputTypes : Seq [AbstractDataType ] = Seq (NumericType , NumericType )
216+
223217 override lazy val resolved =
224218 childrenResolved && checkInputDataTypes().isSuccess && ! DecimalType .isFixed(dataType)
225219
226- protected def checkTypesInternal (t : DataType ) =
227- TypeUtils .checkForNumericExpr(t, " operator " + symbol)
228-
229220 private lazy val integral = dataType match {
230221 case i : IntegralType => i.integral.asInstanceOf [Integral [Any ]]
231222 case i : FractionalType => i.asIntegral.asInstanceOf [Integral [Any ]]
0 commit comments