Skip to content

Commit f783c19

Browse files
committed
add default type-check for BinaryOperator
1 parent e14b545 commit f783c19

File tree

13 files changed

+72
-84
lines changed

13 files changed

+72
-84
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/AbstractSparkSQLParser.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ class SqlLexical extends StdLexical {
9696
",", ";", "%", "{", "}", ":", "[", "]", ".", "&", "|", "^", "~", "<=>"
9797
)
9898

99-
protected override def processIdent(name: String) = {
99+
override protected def processIdent(name: String) = {
100100
val token = normalizeKeyword(name)
101101
if (reserved contains token) Keyword(token) else Identifier(name)
102102
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -419,7 +419,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
419419

420420
private[this] lazy val cast: Any => Any = cast(child.dataType, dataType)
421421

422-
protected override def nullSafeEval(input: Any): Any = cast(input)
422+
override protected def nullSafeEval(input: Any): Any = cast(input)
423423

424424
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
425425
// TODO: Add support for more data types.

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -344,6 +344,18 @@ abstract class BinaryOperator extends BinaryExpression {
344344
def symbol: String
345345

346346
override def toString: String = s"($left $symbol $right)"
347+
348+
override def checkInputDataTypes(): TypeCheckResult = {
349+
if (left.dataType != right.dataType) {
350+
TypeCheckResult.TypeCheckFailure(
351+
s"differing types in ${this.getClass.getSimpleName} " +
352+
s"(${left.dataType} and ${right.dataType}).")
353+
} else {
354+
checkTypesInternal(dataType)
355+
}
356+
}
357+
358+
protected def checkTypesInternal(t: DataType): TypeCheckResult
347359
}
348360

349361

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala

Lines changed: 13 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ case class UnaryMinus(child: Expression) extends UnaryArithmetic {
4242
case dt: NumericType => defineCodeGen(ctx, ev, c => s"(${ctx.javaType(dt)})(-($c))")
4343
}
4444

45-
protected override def nullSafeEval(input: Any): Any = numeric.negate(input)
45+
override protected def nullSafeEval(input: Any): Any = numeric.negate(input)
4646
}
4747

4848
case class UnaryPositive(child: Expression) extends UnaryArithmetic {
@@ -51,7 +51,7 @@ case class UnaryPositive(child: Expression) extends UnaryArithmetic {
5151
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String =
5252
defineCodeGen(ctx, ev, c => c)
5353

54-
protected override def nullSafeEval(input: Any): Any = input
54+
override protected def nullSafeEval(input: Any): Any = input
5555
}
5656

5757
/**
@@ -63,26 +63,14 @@ case class Abs(child: Expression) extends UnaryArithmetic {
6363

6464
private lazy val numeric = TypeUtils.getNumeric(dataType)
6565

66-
protected override def nullSafeEval(input: Any): Any = numeric.abs(input)
66+
override protected def nullSafeEval(input: Any): Any = numeric.abs(input)
6767
}
6868

6969
abstract class BinaryArithmetic extends BinaryOperator {
7070
self: Product =>
7171

7272
override def dataType: DataType = left.dataType
7373

74-
override def checkInputDataTypes(): TypeCheckResult = {
75-
if (left.dataType != right.dataType) {
76-
TypeCheckResult.TypeCheckFailure(
77-
s"differing types in ${this.getClass.getSimpleName} " +
78-
s"(${left.dataType} and ${right.dataType}).")
79-
} else {
80-
checkTypesInternal(dataType)
81-
}
82-
}
83-
84-
protected def checkTypesInternal(t: DataType): TypeCheckResult
85-
8674
/** Name of the function for this expression on a [[Decimal]] type. */
8775
def decimalMethod: String =
8876
sys.error("BinaryArithmetics must override either decimalMethod or genCode")
@@ -110,12 +98,12 @@ case class Add(left: Expression, right: Expression) extends BinaryArithmetic {
11098
override lazy val resolved =
11199
childrenResolved && checkInputDataTypes().isSuccess && !DecimalType.isFixed(dataType)
112100

113-
protected def checkTypesInternal(t: DataType) =
101+
override protected def checkTypesInternal(t: DataType) =
114102
TypeUtils.checkForNumericExpr(t, "operator " + symbol)
115103

116104
private lazy val numeric = TypeUtils.getNumeric(dataType)
117105

118-
protected override def nullSafeEval(input1: Any, input2: Any): Any = numeric.plus(input1, input2)
106+
override protected def nullSafeEval(input1: Any, input2: Any): Any = numeric.plus(input1, input2)
119107
}
120108

121109
case class Subtract(left: Expression, right: Expression) extends BinaryArithmetic {
@@ -125,12 +113,12 @@ case class Subtract(left: Expression, right: Expression) extends BinaryArithmeti
125113
override lazy val resolved =
126114
childrenResolved && checkInputDataTypes().isSuccess && !DecimalType.isFixed(dataType)
127115

128-
protected def checkTypesInternal(t: DataType) =
116+
override protected def checkTypesInternal(t: DataType) =
129117
TypeUtils.checkForNumericExpr(t, "operator " + symbol)
130118

131119
private lazy val numeric = TypeUtils.getNumeric(dataType)
132120

133-
protected override def nullSafeEval(input1: Any, input2: Any): Any = numeric.minus(input1, input2)
121+
override protected def nullSafeEval(input1: Any, input2: Any): Any = numeric.minus(input1, input2)
134122
}
135123

136124
case class Multiply(left: Expression, right: Expression) extends BinaryArithmetic {
@@ -140,12 +128,12 @@ case class Multiply(left: Expression, right: Expression) extends BinaryArithmeti
140128
override lazy val resolved =
141129
childrenResolved && checkInputDataTypes().isSuccess && !DecimalType.isFixed(dataType)
142130

143-
protected def checkTypesInternal(t: DataType) =
131+
override protected def checkTypesInternal(t: DataType) =
144132
TypeUtils.checkForNumericExpr(t, "operator " + symbol)
145133

146134
private lazy val numeric = TypeUtils.getNumeric(dataType)
147135

148-
protected override def nullSafeEval(input1: Any, input2: Any): Any = numeric.times(input1, input2)
136+
override protected def nullSafeEval(input1: Any, input2: Any): Any = numeric.times(input1, input2)
149137
}
150138

151139
case class Divide(left: Expression, right: Expression) extends BinaryArithmetic {
@@ -157,7 +145,7 @@ case class Divide(left: Expression, right: Expression) extends BinaryArithmetic
157145
override lazy val resolved =
158146
childrenResolved && checkInputDataTypes().isSuccess && !DecimalType.isFixed(dataType)
159147

160-
protected def checkTypesInternal(t: DataType) =
148+
override protected def checkTypesInternal(t: DataType) =
161149
TypeUtils.checkForNumericExpr(t, "operator " + symbol)
162150

163151
private lazy val div: (Any, Any) => Any = dataType match {
@@ -223,7 +211,7 @@ case class Remainder(left: Expression, right: Expression) extends BinaryArithmet
223211
override lazy val resolved =
224212
childrenResolved && checkInputDataTypes().isSuccess && !DecimalType.isFixed(dataType)
225213

226-
protected def checkTypesInternal(t: DataType) =
214+
override protected def checkTypesInternal(t: DataType) =
227215
TypeUtils.checkForNumericExpr(t, "operator " + symbol)
228216

229217
private lazy val integral = dataType match {
@@ -283,7 +271,7 @@ case class Remainder(left: Expression, right: Expression) extends BinaryArithmet
283271
case class MaxOf(left: Expression, right: Expression) extends BinaryArithmetic {
284272
override def nullable: Boolean = left.nullable && right.nullable
285273

286-
protected def checkTypesInternal(t: DataType) =
274+
override protected def checkTypesInternal(t: DataType) =
287275
TypeUtils.checkForOrderingExpr(t, "function maxOf")
288276

289277
private lazy val ordering = TypeUtils.getOrdering(dataType)
@@ -337,7 +325,7 @@ case class MaxOf(left: Expression, right: Expression) extends BinaryArithmetic {
337325
case class MinOf(left: Expression, right: Expression) extends BinaryArithmetic {
338326
override def nullable: Boolean = left.nullable && right.nullable
339327

340-
protected def checkTypesInternal(t: DataType) =
328+
override protected def checkTypesInternal(t: DataType) =
341329
TypeUtils.checkForOrderingExpr(t, "function minOf")
342330

343331
private lazy val ordering = TypeUtils.getOrdering(dataType)

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwise.scala

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ import org.apache.spark.sql.types._
3131
case class BitwiseAnd(left: Expression, right: Expression) extends BinaryArithmetic {
3232
override def symbol: String = "&"
3333

34-
protected def checkTypesInternal(t: DataType) =
34+
override protected def checkTypesInternal(t: DataType) =
3535
TypeUtils.checkForBitwiseExpr(t, "operator " + symbol)
3636

3737
private lazy val and: (Any, Any) => Any = dataType match {
@@ -45,7 +45,7 @@ case class BitwiseAnd(left: Expression, right: Expression) extends BinaryArithme
4545
((evalE1: Long, evalE2: Long) => evalE1 & evalE2).asInstanceOf[(Any, Any) => Any]
4646
}
4747

48-
protected override def nullSafeEval(input1: Any, input2: Any): Any = and(input1, input2)
48+
override protected def nullSafeEval(input1: Any, input2: Any): Any = and(input1, input2)
4949
}
5050

5151
/**
@@ -56,7 +56,7 @@ case class BitwiseAnd(left: Expression, right: Expression) extends BinaryArithme
5656
case class BitwiseOr(left: Expression, right: Expression) extends BinaryArithmetic {
5757
override def symbol: String = "|"
5858

59-
protected def checkTypesInternal(t: DataType) =
59+
override protected def checkTypesInternal(t: DataType) =
6060
TypeUtils.checkForBitwiseExpr(t, "operator " + symbol)
6161

6262
private lazy val or: (Any, Any) => Any = dataType match {
@@ -70,7 +70,7 @@ case class BitwiseOr(left: Expression, right: Expression) extends BinaryArithmet
7070
((evalE1: Long, evalE2: Long) => evalE1 | evalE2).asInstanceOf[(Any, Any) => Any]
7171
}
7272

73-
protected override def nullSafeEval(input1: Any, input2: Any): Any = or(input1, input2)
73+
override protected def nullSafeEval(input1: Any, input2: Any): Any = or(input1, input2)
7474
}
7575

7676
/**
@@ -81,7 +81,7 @@ case class BitwiseOr(left: Expression, right: Expression) extends BinaryArithmet
8181
case class BitwiseXor(left: Expression, right: Expression) extends BinaryArithmetic {
8282
override def symbol: String = "^"
8383

84-
protected def checkTypesInternal(t: DataType) =
84+
override protected def checkTypesInternal(t: DataType) =
8585
TypeUtils.checkForBitwiseExpr(t, "operator " + symbol)
8686

8787
private lazy val xor: (Any, Any) => Any = dataType match {
@@ -95,7 +95,7 @@ case class BitwiseXor(left: Expression, right: Expression) extends BinaryArithme
9595
((evalE1: Long, evalE2: Long) => evalE1 ^ evalE2).asInstanceOf[(Any, Any) => Any]
9696
}
9797

98-
protected override def nullSafeEval(input1: Any, input2: Any): Any = xor(input1, input2)
98+
override protected def nullSafeEval(input1: Any, input2: Any): Any = xor(input1, input2)
9999
}
100100

101101
/**
@@ -122,5 +122,5 @@ case class BitwiseNot(child: Expression) extends UnaryArithmetic {
122122
defineCodeGen(ctx, ev, c => s"(${ctx.javaType(dataType)}) ~($c)")
123123
}
124124

125-
protected override def nullSafeEval(input: Any): Any = not(input)
125+
override protected def nullSafeEval(input: Any): Any = not(input)
126126
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,7 @@ case class GetStructField(child: Expression, field: StructField, ordinal: Int)
130130
override def dataType: DataType = field.dataType
131131
override def nullable: Boolean = child.nullable || field.nullable
132132

133-
protected override def nullSafeEval(input: Any): Any =
133+
override protected def nullSafeEval(input: Any): Any =
134134
input.asInstanceOf[InternalRow](ordinal)
135135

136136
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
@@ -160,7 +160,7 @@ case class GetArrayStructFields(
160160
override def dataType: DataType = ArrayType(field.dataType, containsNull)
161161
override def nullable: Boolean = child.nullable || containsNull || field.nullable
162162

163-
protected override def nullSafeEval(input: Any): Any = {
163+
override protected def nullSafeEval(input: Any): Any = {
164164
input.asInstanceOf[Seq[InternalRow]].map { row =>
165165
if (row == null) null else row(ordinal)
166166
}
@@ -204,7 +204,7 @@ case class GetArrayItem(child: Expression, ordinal: Expression)
204204

205205
override def dataType: DataType = child.dataType.asInstanceOf[ArrayType].elementType
206206

207-
protected override def nullSafeEval(value: Any, ordinal: Any): Any = {
207+
override protected def nullSafeEval(value: Any, ordinal: Any): Any = {
208208
// TODO: consider using Array[_] for ArrayType child to avoid
209209
// boxing of primitives
210210
val baseValue = value.asInstanceOf[Seq[_]]
@@ -248,7 +248,7 @@ case class GetMapValue(child: Expression, key: Expression)
248248

249249
override def dataType: DataType = child.dataType.asInstanceOf[MapType].valueType
250250

251-
protected override def nullSafeEval(value: Any, ordinal: Any): Any = {
251+
override protected def nullSafeEval(value: Any, ordinal: Any): Any = {
252252
val baseValue = value.asInstanceOf[Map[Any, _]]
253253
baseValue.get(ordinal).orNull
254254
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalFunctions.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ case class UnscaledValue(child: Expression) extends UnaryExpression {
3030
override def dataType: DataType = LongType
3131
override def toString: String = s"UnscaledValue($child)"
3232

33-
protected override def nullSafeEval(input: Any): Any =
33+
override protected def nullSafeEval(input: Any): Any =
3434
input.asInstanceOf[Decimal].toUnscaledLong
3535

3636
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
@@ -48,7 +48,7 @@ case class MakeDecimal(child: Expression, precision: Int, scale: Int) extends Un
4848
override def dataType: DataType = DecimalType(precision, scale)
4949
override def toString: String = s"MakeDecimal($child,$precision,$scale)"
5050

51-
protected override def nullSafeEval(input: Any): Any =
51+
override protected def nullSafeEval(input: Any): Any =
5252
Decimal(input.asInstanceOf[Long], precision, scale)
5353

5454
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ abstract class UnaryMathExpression(f: Double => Double, name: String)
6262
override def nullable: Boolean = true
6363
override def toString: String = s"$name($child)"
6464

65-
protected override def nullSafeEval(input: Any): Any = {
65+
override protected def nullSafeEval(input: Any): Any = {
6666
val result = f(input.asInstanceOf[Double])
6767
if (result.isNaN) null else result
6868
}
@@ -97,7 +97,7 @@ abstract class BinaryMathExpression(f: (Double, Double) => Double, name: String)
9797

9898
override def dataType: DataType = DoubleType
9999

100-
protected override def nullSafeEval(input1: Any, input2: Any): Any = {
100+
override protected def nullSafeEval(input1: Any, input2: Any): Any = {
101101
val result = f(input1.asInstanceOf[Double], input2.asInstanceOf[Double])
102102
if (result.isNaN) null else result
103103
}
@@ -183,7 +183,7 @@ case class Factorial(child: Expression) extends UnaryExpression with ExpectsInpu
183183
// If the value not in the range of [0, 20], it still will be null, so set it to be true here.
184184
override def nullable: Boolean = true
185185

186-
protected override def nullSafeEval(input: Any): Any = {
186+
override protected def nullSafeEval(input: Any): Any = {
187187
val value = input.asInstanceOf[jl.Integer]
188188
if (value > 20 || value < 0) {
189189
null
@@ -256,7 +256,7 @@ case class Bin(child: Expression)
256256
override def inputTypes: Seq[DataType] = Seq(LongType)
257257
override def dataType: DataType = StringType
258258

259-
protected override def nullSafeEval(input: Any): Any =
259+
override protected def nullSafeEval(input: Any): Any =
260260
UTF8String.fromString(jl.Long.toBinaryString(input.asInstanceOf[Long]))
261261

262262
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
@@ -293,7 +293,7 @@ case class Hex(child: Expression) extends UnaryExpression with ExpectsInputTypes
293293

294294
override def dataType: DataType = StringType
295295

296-
protected override def nullSafeEval(num: Any): Any = child.dataType match {
296+
override protected def nullSafeEval(num: Any): Any = child.dataType match {
297297
case LongType => hex(num.asInstanceOf[Long])
298298
case BinaryType => hex(num.asInstanceOf[Array[Byte]])
299299
case StringType => hex(num.asInstanceOf[UTF8String].getBytes)
@@ -337,7 +337,7 @@ case class Unhex(child: Expression) extends UnaryExpression with ExpectsInputTyp
337337
override def nullable: Boolean = true
338338
override def dataType: DataType = BinaryType
339339

340-
protected override def nullSafeEval(num: Any): Any =
340+
override protected def nullSafeEval(num: Any): Any =
341341
unhex(num.asInstanceOf[UTF8String].getBytes)
342342

343343
private[this] def unhex(bytes: Array[Byte]): Array[Byte] = {
@@ -383,7 +383,7 @@ case class Unhex(child: Expression) extends UnaryExpression with ExpectsInputTyp
383383
case class Atan2(left: Expression, right: Expression)
384384
extends BinaryMathExpression(math.atan2, "ATAN2") {
385385

386-
protected override def nullSafeEval(input1: Any, input2: Any): Any = {
386+
override protected def nullSafeEval(input1: Any, input2: Any): Any = {
387387
// With codegen, the values returned by -0.0 and 0.0 are different. Handled with +0.0
388388
val result = math.atan2(input1.asInstanceOf[Double] + 0.0, input2.asInstanceOf[Double] + 0.0)
389389
if (result.isNaN) null else result
@@ -423,7 +423,7 @@ case class ShiftLeft(left: Expression, right: Expression)
423423

424424
override def dataType: DataType = left.dataType
425425

426-
protected override def nullSafeEval(input1: Any, input2: Any): Any = {
426+
override protected def nullSafeEval(input1: Any, input2: Any): Any = {
427427
input1 match {
428428
case l: jl.Long => l << input2.asInstanceOf[jl.Integer]
429429
case i: jl.Integer => i << input2.asInstanceOf[jl.Integer]
@@ -449,7 +449,7 @@ case class ShiftRight(left: Expression, right: Expression)
449449

450450
override def dataType: DataType = left.dataType
451451

452-
protected override def nullSafeEval(input1: Any, input2: Any): Any = {
452+
override protected def nullSafeEval(input1: Any, input2: Any): Any = {
453453
input1 match {
454454
case l: jl.Long => l >> input2.asInstanceOf[jl.Integer]
455455
case i: jl.Integer => i >> input2.asInstanceOf[jl.Integer]
@@ -475,7 +475,7 @@ case class ShiftRightUnsigned(left: Expression, right: Expression)
475475

476476
override def dataType: DataType = left.dataType
477477

478-
protected override def nullSafeEval(input1: Any, input2: Any): Any = {
478+
override protected def nullSafeEval(input1: Any, input2: Any): Any = {
479479
input1 match {
480480
case l: jl.Long => l >>> input2.asInstanceOf[jl.Integer]
481481
case i: jl.Integer => i >>> input2.asInstanceOf[jl.Integer]

0 commit comments

Comments
 (0)