
Commit d017861

Improve expression type checking.
1 parent 3009088 commit d017861

6 files changed, +87 -86 lines changed


sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala

Lines changed: 31 additions & 3 deletions
@@ -24,8 +24,20 @@ import org.apache.spark.sql.catalyst.trees
 import org.apache.spark.sql.catalyst.trees.TreeNode
 import org.apache.spark.sql.types._
 
+////////////////////////////////////////////////////////////////////////////////////////////////////
+// This file defines the basic expression abstract classes in Catalyst, including:
+// Expression: the base expression abstract class
+// LeafExpression
+// UnaryExpression
+// BinaryExpression
+// BinaryOperator
+//
+// For details, see their classdocs.
+////////////////////////////////////////////////////////////////////////////////////////////////////
 
 /**
+ * An expression in Catalyst.
+ *
  * If an expression wants to be exposed in the function registry (so users can call it with
  * "name(arguments...)", the concrete implementation must be a case class whose constructor
  * arguments are all Expressions types.
@@ -335,15 +347,31 @@ abstract class BinaryExpression extends Expression with trees.BinaryNode[Express
 
 
 /**
- * An expression that has two inputs that are expected to the be same type. If the two inputs have
- * different types, the analyzer will find the tightest common type and do the proper type casting.
+ * A [[BinaryExpression]] that is an operator, with two properties:
+ *
+ * 1. The string representation is "x symbol y", rather than "funcName(x, y)".
+ * 2. Two inputs are expected to the be same type. If the two inputs have different types,
+ *    the analyzer will find the tightest common type and do the proper type casting.
  */
 abstract class BinaryOperator extends BinaryExpression {
   self: Product =>
 
   def symbol: String
 
-  override def toString: String = s"($left $symbol $right)"
+  override def checkInputDataTypes(): TypeCheckResult = {
+    // First call the checker for ExpectsInputTypes, and then check whether left and right have
+    // the same type.
+    super.checkInputDataTypes() match {
+      case TypeCheckResult.TypeCheckSuccess =>
+        if (left.dataType != right.dataType) {
+          TypeCheckResult.TypeCheckFailure(s"differing types in '$prettyString' " +
+            s"(${left.dataType.simpleString} and ${right.dataType.simpleString}).")
+        } else {
+          TypeCheckResult.TypeCheckSuccess
+        }
+      case TypeCheckResult.TypeCheckFailure(msg) => TypeCheckResult.TypeCheckFailure(msg)
+    }
+  }
 }
 
 
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala

Lines changed: 32 additions & 41 deletions
@@ -29,11 +29,11 @@ abstract class UnaryArithmetic extends UnaryExpression {
   override def dataType: DataType = child.dataType
 }
 
-case class UnaryMinus(child: Expression) extends UnaryArithmetic {
-  override def toString: String = s"-$child"
+case class UnaryMinus(child: Expression) extends UnaryArithmetic with ExpectsInputTypes {
+
+  override def inputTypes: Seq[AbstractDataType] = Seq(NumericType)
 
-  override def checkInputDataTypes(): TypeCheckResult =
-    TypeUtils.checkForNumericExpr(child.dataType, "operator -")
+  override def toString: String = s"-$child"
 
   private lazy val numeric = TypeUtils.getNumeric(dataType)
 
@@ -57,9 +57,9 @@ case class UnaryPositive(child: Expression) extends UnaryArithmetic {
 /**
  * A function that get the absolute value of the numeric value.
  */
-case class Abs(child: Expression) extends UnaryArithmetic {
-  override def checkInputDataTypes(): TypeCheckResult =
-    TypeUtils.checkForNumericExpr(child.dataType, "function abs")
+case class Abs(child: Expression) extends UnaryArithmetic with ExpectsInputTypes {
+
+  override def inputTypes: Seq[AbstractDataType] = Seq(NumericType)
 
   private lazy val numeric = TypeUtils.getNumeric(dataType)
 
@@ -71,18 +71,6 @@ abstract class BinaryArithmetic extends BinaryOperator {
 
   override def dataType: DataType = left.dataType
 
-  override def checkInputDataTypes(): TypeCheckResult = {
-    if (left.dataType != right.dataType) {
-      TypeCheckResult.TypeCheckFailure(
-        s"differing types in ${this.getClass.getSimpleName} " +
-        s"(${left.dataType} and ${right.dataType}).")
-    } else {
-      checkTypesInternal(dataType)
-    }
-  }
-
-  protected def checkTypesInternal(t: DataType): TypeCheckResult
-
   /** Name of the function for this expression on a [[Decimal]] type. */
   def decimalMethod: String =
     sys.error("BinaryArithmetics must override either decimalMethod or genCode")
@@ -103,63 +91,66 @@ private[sql] object BinaryArithmetic {
   def unapply(e: BinaryArithmetic): Option[(Expression, Expression)] = Some((e.left, e.right))
 }
 
-case class Add(left: Expression, right: Expression) extends BinaryArithmetic {
+case class Add(left: Expression, right: Expression)
+  extends BinaryArithmetic with ExpectsInputTypes {
+
   override def symbol: String = "+"
   override def decimalMethod: String = "$plus"
 
+  override def inputTypes: Seq[AbstractDataType] = Seq(NumericType, NumericType)
+
   override lazy val resolved =
     childrenResolved && checkInputDataTypes().isSuccess && !DecimalType.isFixed(dataType)
 
-  protected def checkTypesInternal(t: DataType) =
-    TypeUtils.checkForNumericExpr(t, "operator " + symbol)
-
   private lazy val numeric = TypeUtils.getNumeric(dataType)
 
   protected override def nullSafeEval(input1: Any, input2: Any): Any = numeric.plus(input1, input2)
 }
 
-case class Subtract(left: Expression, right: Expression) extends BinaryArithmetic {
+case class Subtract(left: Expression, right: Expression)
+  extends BinaryArithmetic with ExpectsInputTypes {
+
   override def symbol: String = "-"
   override def decimalMethod: String = "$minus"
 
+  override def inputTypes: Seq[AbstractDataType] = Seq(NumericType, NumericType)
+
   override lazy val resolved =
     childrenResolved && checkInputDataTypes().isSuccess && !DecimalType.isFixed(dataType)
 
-  protected def checkTypesInternal(t: DataType) =
-    TypeUtils.checkForNumericExpr(t, "operator " + symbol)
-
   private lazy val numeric = TypeUtils.getNumeric(dataType)
 
   protected override def nullSafeEval(input1: Any, input2: Any): Any = numeric.minus(input1, input2)
 }
 
-case class Multiply(left: Expression, right: Expression) extends BinaryArithmetic {
+case class Multiply(left: Expression, right: Expression)
+  extends BinaryArithmetic with ExpectsInputTypes {
+
   override def symbol: String = "*"
   override def decimalMethod: String = "$times"
 
+  override def inputTypes: Seq[AbstractDataType] = Seq(NumericType, NumericType)
+
   override lazy val resolved =
     childrenResolved && checkInputDataTypes().isSuccess && !DecimalType.isFixed(dataType)
 
-  protected def checkTypesInternal(t: DataType) =
-    TypeUtils.checkForNumericExpr(t, "operator " + symbol)
-
   private lazy val numeric = TypeUtils.getNumeric(dataType)
 
   protected override def nullSafeEval(input1: Any, input2: Any): Any = numeric.times(input1, input2)
 }
 
-case class Divide(left: Expression, right: Expression) extends BinaryArithmetic {
+case class Divide(left: Expression, right: Expression)
+  extends BinaryArithmetic with ExpectsInputTypes {
+
   override def symbol: String = "/"
   override def decimalMethod: String = "$div"
-
   override def nullable: Boolean = true
 
+  override def inputTypes: Seq[AbstractDataType] = Seq(NumericType, NumericType)
+
   override lazy val resolved =
     childrenResolved && checkInputDataTypes().isSuccess && !DecimalType.isFixed(dataType)
 
-  protected def checkTypesInternal(t: DataType) =
-    TypeUtils.checkForNumericExpr(t, "operator " + symbol)
-
   private lazy val div: (Any, Any) => Any = dataType match {
     case ft: FractionalType => ft.fractional.asInstanceOf[Fractional[Any]].div
     case it: IntegralType => it.integral.asInstanceOf[Integral[Any]].quot
@@ -214,18 +205,18 @@ case class Divide(left: Expression, right: Expression) extends BinaryArithmetic
   }
 }
 
-case class Remainder(left: Expression, right: Expression) extends BinaryArithmetic {
+case class Remainder(left: Expression, right: Expression)
+  extends BinaryArithmetic with ExpectsInputTypes {
+
   override def symbol: String = "%"
   override def decimalMethod: String = "remainder"
-
   override def nullable: Boolean = true
 
+  override def inputTypes: Seq[AbstractDataType] = Seq(NumericType, NumericType)
+
   override lazy val resolved =
     childrenResolved && checkInputDataTypes().isSuccess && !DecimalType.isFixed(dataType)
 
-  protected def checkTypesInternal(t: DataType) =
-    TypeUtils.checkForNumericExpr(t, "operator " + symbol)
-
   private lazy val integral = dataType match {
     case i: IntegralType => i.integral.asInstanceOf[Integral[Any]]
    case i: FractionalType => i.asIntegral.asInstanceOf[Integral[Any]]
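The pattern repeated across Add, Subtract, Multiply, Divide and Remainder is the same: the per-operator checkTypesInternal overrides are replaced by a declaration, inputTypes = Seq(NumericType, NumericType), that a shared checker consumes. A toy sketch of that declarative shape (names invented for illustration, not Spark code):

    // Toy sketch: one generic checker driven by per-operator declarations, instead of each
    // operator re-implementing its own numeric check.
    object DeclaredInputTypesSketch {
      type TypeName = String

      // The set of type names we treat as "numeric"; stands in for Catalyst's NumericType.
      val numericTypes: Set[TypeName] = Set("int", "long", "float", "double", "decimal")

      // Each operator only declares its symbol and expected input types; there is no
      // per-operator check method any more.
      final case class OperatorDecl(symbol: String, inputTypes: Seq[TypeName => Boolean])

      val add      = OperatorDecl("+", Seq(numericTypes, numericTypes))
      val subtract = OperatorDecl("-", Seq(numericTypes, numericTypes))
      val multiply = OperatorDecl("*", Seq(numericTypes, numericTypes))

      // One shared checker, analogous to what the ExpectsInputTypes trait provides.
      def check(op: OperatorDecl, argTypes: Seq[TypeName]): Either[String, Unit] =
        argTypes.zip(op.inputTypes).zipWithIndex.collectFirst {
          case ((arg, accepts), i) if !accepts(arg) =>
            s"operator ${op.symbol}: argument ${i + 1} has unsupported type $arg"
        }.toLeft(())

      def main(args: Array[String]): Unit = {
        println(check(add, Seq("int", "int")))          // Right(())
        println(check(subtract, Seq("long", "double"))) // Right(())
        println(check(multiply, Seq("int", "string")))  // Left(operator *: argument 2 ...)
      }
    }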

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwise.scala

Lines changed: 18 additions & 14 deletions
@@ -17,9 +17,7 @@
 
 package org.apache.spark.sql.catalyst.expressions
 
-import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
 import org.apache.spark.sql.catalyst.expressions.codegen._
-import org.apache.spark.sql.catalyst.util.TypeUtils
 import org.apache.spark.sql.types._
 
 
@@ -28,11 +26,13 @@ import org.apache.spark.sql.types._
  *
  * Code generation inherited from BinaryArithmetic.
  */
-case class BitwiseAnd(left: Expression, right: Expression) extends BinaryArithmetic {
+case class BitwiseAnd(left: Expression, right: Expression)
+  extends BinaryArithmetic with ExpectsInputTypes {
+
   override def symbol: String = "&"
 
-  protected def checkTypesInternal(t: DataType) =
-    TypeUtils.checkForBitwiseExpr(t, "operator " + symbol)
+  override def inputTypes: Seq[AbstractDataType] =
+    Seq(TypeCollection(IntegerType, LongType), TypeCollection(IntegerType, LongType))
 
   private lazy val and: (Any, Any) => Any = dataType match {
     case ByteType =>
@@ -53,11 +53,13 @@ case class BitwiseAnd(left: Expression, right: Expression) extends BinaryArithme
  *
  * Code generation inherited from BinaryArithmetic.
  */
-case class BitwiseOr(left: Expression, right: Expression) extends BinaryArithmetic {
+case class BitwiseOr(left: Expression, right: Expression)
+  extends BinaryArithmetic with ExpectsInputTypes {
+
   override def symbol: String = "|"
 
-  protected def checkTypesInternal(t: DataType) =
-    TypeUtils.checkForBitwiseExpr(t, "operator " + symbol)
+  override def inputTypes: Seq[AbstractDataType] =
+    Seq(TypeCollection(IntegerType, LongType), TypeCollection(IntegerType, LongType))
 
   private lazy val or: (Any, Any) => Any = dataType match {
     case ByteType =>
@@ -78,11 +80,13 @@ case class BitwiseOr(left: Expression, right: Expression) extends BinaryArithmet
  *
  * Code generation inherited from BinaryArithmetic.
  */
-case class BitwiseXor(left: Expression, right: Expression) extends BinaryArithmetic {
+case class BitwiseXor(left: Expression, right: Expression)
+  extends BinaryArithmetic with ExpectsInputTypes {
+
   override def symbol: String = "^"
 
-  protected def checkTypesInternal(t: DataType) =
-    TypeUtils.checkForBitwiseExpr(t, "operator " + symbol)
+  override def inputTypes: Seq[AbstractDataType] =
+    Seq(TypeCollection(IntegerType, LongType), TypeCollection(IntegerType, LongType))
 
   private lazy val xor: (Any, Any) => Any = dataType match {
     case ByteType =>
@@ -101,11 +105,11 @@ case class BitwiseXor(left: Expression, right: Expression) extends BinaryArithme
 /**
  * A function that calculates bitwise not(~) of a number.
  */
-case class BitwiseNot(child: Expression) extends UnaryArithmetic {
+case class BitwiseNot(child: Expression) extends UnaryArithmetic with ExpectsInputTypes {
+
   override def toString: String = s"~$child"
 
-  override def checkInputDataTypes(): TypeCheckResult =
-    TypeUtils.checkForBitwiseExpr(child.dataType, "operator ~")
+  override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(IntegerType, LongType))
 
   private lazy val not: (Any) => Any = dataType match {
     case ByteType =>
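For the bitwise operators the declared input type is a TypeCollection rather than a single abstract type: an argument is acceptable if its type matches any member of the collection. A standalone sketch of that acceptance rule (toy types, not the real TypeCollection):

    // Toy model of the TypeCollection acceptance rule used by the bitwise operators:
    // an argument type is acceptable if it matches any member of the collection.
    object TypeCollectionSketch {
      final case class ToyTypeCollection(members: Set[String]) {
        def acceptsType(t: String): Boolean = members.contains(t)
      }

      // BitwiseAnd-style declaration: each of the two operands must be int or long.
      val integerOrLong = ToyTypeCollection(Set("int", "long"))
      val bitwiseAndInputTypes = Seq(integerOrLong, integerOrLong)

      // Returns the 1-based position of the first argument whose type is not acceptable.
      def firstBadArgument(argTypes: Seq[String], expected: Seq[ToyTypeCollection]): Option[Int] =
        argTypes.zip(expected).zipWithIndex.collectFirst {
          case ((arg, coll), i) if !coll.acceptsType(arg) => i + 1
        }

      def main(args: Array[String]): Unit = {
        // Both operands pass the collection check (the separate same-type check is not modelled here).
        println(firstBadArgument(Seq("int", "long"), bitwiseAndInputTypes))   // None
        // A double is rejected: bitwise operands must be integral.
        println(firstBadArgument(Seq("int", "double"), bitwiseAndInputTypes)) // Some(2)
      }
    }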

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala

Lines changed: 4 additions & 20 deletions
@@ -213,18 +213,6 @@ case class Or(left: Expression, right: Expression)
 abstract class BinaryComparison extends BinaryOperator with Predicate {
   self: Product =>
 
-  override def checkInputDataTypes(): TypeCheckResult = {
-    if (left.dataType != right.dataType) {
-      TypeCheckResult.TypeCheckFailure(
-        s"differing types in ${this.getClass.getSimpleName} " +
-        s"(${left.dataType} and ${right.dataType}).")
-    } else {
-      checkTypesInternal(dataType)
-    }
-  }
-
-  protected def checkTypesInternal(t: DataType): TypeCheckResult
-
   override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
     if (ctx.isPrimitiveType(left.dataType)) {
       // faster version
@@ -251,8 +239,6 @@ private[sql] object Equality {
 case class EqualTo(left: Expression, right: Expression) extends BinaryComparison {
   override def symbol: String = "="
 
-  override protected def checkTypesInternal(t: DataType) = TypeCheckResult.TypeCheckSuccess
-
   protected override def nullSafeEval(input1: Any, input2: Any): Any = {
     if (left.dataType != BinaryType) input1 == input2
     else java.util.Arrays.equals(input1.asInstanceOf[Array[Byte]], input2.asInstanceOf[Array[Byte]])
@@ -268,8 +254,6 @@ case class EqualNullSafe(left: Expression, right: Expression) extends BinaryComp
 
   override def nullable: Boolean = false
 
-  override protected def checkTypesInternal(t: DataType) = TypeCheckResult.TypeCheckSuccess
-
   override def eval(input: InternalRow): Any = {
     val input1 = left.eval(input)
     val input2 = right.eval(input)
@@ -301,7 +285,7 @@ case class EqualNullSafe(left: Expression, right: Expression) extends BinaryComp
 case class LessThan(left: Expression, right: Expression) extends BinaryComparison {
   override def symbol: String = "<"
 
-  override protected def checkTypesInternal(t: DataType) =
+  protected def checkTypesInternal(t: DataType) =
     TypeUtils.checkForOrderingExpr(left.dataType, "operator " + symbol)
 
   private lazy val ordering = TypeUtils.getOrdering(left.dataType)
@@ -312,7 +296,7 @@ case class LessThan(left: Expression, right: Expression) extends BinaryCompariso
 case class LessThanOrEqual(left: Expression, right: Expression) extends BinaryComparison {
   override def symbol: String = "<="
 
-  override protected def checkTypesInternal(t: DataType) =
+  protected def checkTypesInternal(t: DataType) =
     TypeUtils.checkForOrderingExpr(left.dataType, "operator " + symbol)
 
   private lazy val ordering = TypeUtils.getOrdering(left.dataType)
@@ -323,7 +307,7 @@ case class LessThanOrEqual(left: Expression, right: Expression) extends BinaryCo
 case class GreaterThan(left: Expression, right: Expression) extends BinaryComparison {
   override def symbol: String = ">"
 
-  override protected def checkTypesInternal(t: DataType) =
+  protected def checkTypesInternal(t: DataType) =
     TypeUtils.checkForOrderingExpr(left.dataType, "operator " + symbol)
 
   private lazy val ordering = TypeUtils.getOrdering(left.dataType)
@@ -334,7 +318,7 @@ case class GreaterThan(left: Expression, right: Expression) extends BinaryCompar
 case class GreaterThanOrEqual(left: Expression, right: Expression) extends BinaryComparison {
   override def symbol: String = ">="
 
-  override protected def checkTypesInternal(t: DataType) =
+  protected def checkTypesInternal(t: DataType) =
     TypeUtils.checkForOrderingExpr(left.dataType, "operator " + symbol)
 
   private lazy val ordering = TypeUtils.getOrdering(left.dataType)

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala

Lines changed: 0 additions & 8 deletions
@@ -32,14 +32,6 @@ object TypeUtils {
     }
   }
 
-  def checkForBitwiseExpr(t: DataType, caller: String): TypeCheckResult = {
-    if (t.isInstanceOf[IntegralType] || t == NullType) {
-      TypeCheckResult.TypeCheckSuccess
-    } else {
-      TypeCheckResult.TypeCheckFailure(s"$caller accepts integral types, not $t")
-    }
-  }
-
   def checkForOrderingExpr(t: DataType, caller: String): TypeCheckResult = {
     if (t.isInstanceOf[AtomicType] || t == NullType) {
       TypeCheckResult.TypeCheckSuccess

sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala

Lines changed: 2 additions & 0 deletions
@@ -96,6 +96,8 @@ private[sql] class TypeCollection(private val types: Seq[AbstractDataType])
 
 private[sql] object TypeCollection {
 
+  val Ordered = TypeCollection(NumericType, StringType)
+
   def apply(types: AbstractDataType*): TypeCollection = new TypeCollection(types)
 
   def unapply(typ: AbstractDataType): Option[Seq[AbstractDataType]] = typ match {
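TypeCollection.Ordered is defined here but not used anywhere else in this commit; it presumably anticipates expressions whose operands must be orderable (numeric or string). A toy sketch of what such a named collection expresses (hypothetical usage, not part of this diff):

    // Toy stand-in for a reusable "Ordered" collection constant, mirroring
    // TypeCollection(NumericType, StringType). The usage shown is hypothetical.
    object OrderedCollectionSketch {
      final case class ToyTypeCollection(accepts: String => Boolean)

      val numericTypes: Set[String] = Set("int", "long", "float", "double", "decimal")

      // Orderable here means numeric or string.
      val Ordered: ToyTypeCollection = ToyTypeCollection(t => numericTypes.contains(t) || t == "string")

      def main(args: Array[String]): Unit = {
        println(Ordered.accepts("string"))  // true
        println(Ordered.accepts("double"))  // true
        println(Ordered.accepts("map"))     // false: not orderable under this rule
      }
    }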
