Skip to content

Commit f23a721

Browse files
committed
[SPARK-8993][SQL] More comprehensive type checking in expressions.
This patch makes the following changes: 1. ExpectsInputTypes only defines expected input types, but does not perform any implicit type casting. 2. ImplicitCastInputTypes is a new trait that defines both expected input types, as well as performs implicit type casting. 3. BinaryOperator has a new abstract function "inputType", which defines the expected input type for both left/right. Concrete BinaryOperator expressions no longer perform any implicit type casting. 4. For BinaryOperators, convert NullType (i.e. null literals) into some accepted type so BinaryOperators don't need to handle NullTypes. TODOs needed: fix unit tests for error reporting. I'm intentionally not changing anything in aggregate expressions because yhuai is doing a big refactoring on that right now. Author: Reynold Xin <[email protected]> Closes apache#7348 from rxin/typecheck and squashes the following commits: 8fcf814 [Reynold Xin] Fixed ordering of cases. 3bb63e7 [Reynold Xin] Style fix. f45408f [Reynold Xin] Comment update. aa7790e [Reynold Xin] Moved RemoveNullTypes into ImplicitTypeCasts. 438ea07 [Reynold Xin] space d55c9e5 [Reynold Xin] Removes NullTypes. 360d124 [Reynold Xin] Fixed the rule. fb66657 [Reynold Xin] Convert NullType into some accepted type for BinaryOperators. 2e22330 [Reynold Xin] Fixed unit tests. 4932d57 [Reynold Xin] Style fix. d061691 [Reynold Xin] Rename existing ExpectsInputTypes -> ImplicitCastInputTypes. e4727cc [Reynold Xin] BinaryOperator should not be doing implicit cast. d017861 [Reynold Xin] Improve expression type checking.
1 parent f650a00 commit f23a721

File tree

17 files changed

+309
-165
lines changed

17 files changed

+309
-165
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
package org.apache.spark.sql.catalyst.analysis
1919

20+
import scala.language.existentials
2021
import scala.reflect.ClassTag
2122
import scala.util.{Failure, Success, Try}
2223

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala

Lines changed: 27 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -214,19 +214,6 @@ object HiveTypeCoercion {
214214
}
215215

216216
Union(newLeft, newRight)
217-
218-
// Also widen types for BinaryOperator.
219-
case q: LogicalPlan => q transformExpressions {
220-
// Skip nodes who's children have not been resolved yet.
221-
case e if !e.childrenResolved => e
222-
223-
case b @ BinaryOperator(left, right) if left.dataType != right.dataType =>
224-
findTightestCommonTypeOfTwo(left.dataType, right.dataType).map { widestType =>
225-
val newLeft = if (left.dataType == widestType) left else Cast(left, widestType)
226-
val newRight = if (right.dataType == widestType) right else Cast(right, widestType)
227-
b.makeCopy(Array(newLeft, newRight))
228-
}.getOrElse(b) // If there is no applicable conversion, leave expression unchanged.
229-
}
230217
}
231218
}
232219

@@ -672,20 +659,44 @@ object HiveTypeCoercion {
672659
}
673660

674661
/**
675-
* Casts types according to the expected input types for Expressions that have the trait
676-
* [[ExpectsInputTypes]].
662+
* Casts types according to the expected input types for [[Expression]]s.
677663
*/
678664
object ImplicitTypeCasts extends Rule[LogicalPlan] {
679665
def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
680666
// Skip nodes whose children have not been resolved yet.
681667
case e if !e.childrenResolved => e
682668

683-
case e: ExpectsInputTypes if (e.inputTypes.nonEmpty) =>
669+
case b @ BinaryOperator(left, right) if left.dataType != right.dataType =>
670+
findTightestCommonTypeOfTwo(left.dataType, right.dataType).map { commonType =>
671+
if (b.inputType.acceptsType(commonType)) {
672+
// If the expression accepts the tightest common type, cast to that.
673+
val newLeft = if (left.dataType == commonType) left else Cast(left, commonType)
674+
val newRight = if (right.dataType == commonType) right else Cast(right, commonType)
675+
b.makeCopy(Array(newLeft, newRight))
676+
} else {
677+
// Otherwise, don't do anything with the expression.
678+
b
679+
}
680+
}.getOrElse(b) // If there is no applicable conversion, leave expression unchanged.
681+
682+
case e: ImplicitCastInputTypes if e.inputTypes.nonEmpty =>
684683
val children: Seq[Expression] = e.children.zip(e.inputTypes).map { case (in, expected) =>
685684
// If we cannot do the implicit cast, just use the original input.
686685
implicitCast(in, expected).getOrElse(in)
687686
}
688687
e.withNewChildren(children)
688+
689+
case e: ExpectsInputTypes if e.inputTypes.nonEmpty =>
690+
// Convert NullType into some specific target type for ExpectsInputTypes that don't do
691+
// general implicit casting.
692+
val children: Seq[Expression] = e.children.zip(e.inputTypes).map { case (in, expected) =>
693+
if (in.dataType == NullType && !expected.acceptsType(NullType)) {
694+
Cast(in, expected.defaultConcreteType)
695+
} else {
696+
in
697+
}
698+
}
699+
e.withNewChildren(children)
689700
}
690701

691702
/**

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpectsInputTypes.scala

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,15 @@ package org.apache.spark.sql.catalyst.expressions
1919

2020
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
2121
import org.apache.spark.sql.types.AbstractDataType
22-
22+
import org.apache.spark.sql.catalyst.analysis.HiveTypeCoercion.ImplicitTypeCasts
2323

2424
/**
2525
* A trait that gets mixed in to define the expected input types of an expression.
26+
*
27+
* This trait is typically used by operator expressions (e.g. [[Add]], [[Subtract]]) to define
28+
* expected input types without any implicit casting.
29+
*
30+
* Most function expressions (e.g. [[Substring]]) should extend [[ImplicitCastInputTypes]] instead.
2631
*/
2732
trait ExpectsInputTypes { self: Expression =>
2833

@@ -40,7 +45,7 @@ trait ExpectsInputTypes { self: Expression =>
4045
val mismatches = children.zip(inputTypes).zipWithIndex.collect {
4146
case ((child, expected), idx) if !expected.acceptsType(child.dataType) =>
4247
s"argument ${idx + 1} is expected to be of type ${expected.simpleString}, " +
43-
s"however, '${child.prettyString}' is of type ${child.dataType.simpleString}."
48+
s"however, '${child.prettyString}' is of type ${child.dataType.simpleString}."
4449
}
4550

4651
if (mismatches.isEmpty) {
@@ -50,3 +55,11 @@ trait ExpectsInputTypes { self: Expression =>
5055
}
5156
}
5257
}
58+
59+
60+
/**
61+
* A mixin for the analyzer to perform implicit type casting using [[ImplicitTypeCasts]].
62+
*/
63+
trait ImplicitCastInputTypes extends ExpectsInputTypes { self: Expression =>
64+
// No other methods
65+
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala

Lines changed: 41 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,20 @@ import org.apache.spark.sql.catalyst.trees
2424
import org.apache.spark.sql.catalyst.trees.TreeNode
2525
import org.apache.spark.sql.types._
2626

27+
////////////////////////////////////////////////////////////////////////////////////////////////////
28+
// This file defines the basic expression abstract classes in Catalyst, including:
29+
// Expression: the base expression abstract class
30+
// LeafExpression
31+
// UnaryExpression
32+
// BinaryExpression
33+
// BinaryOperator
34+
//
35+
// For details, see their classdocs.
36+
////////////////////////////////////////////////////////////////////////////////////////////////////
2737

2838
/**
39+
* An expression in Catalyst.
40+
*
2941
* If an expression wants to be exposed in the function registry (so users can call it with
3042
* "name(arguments...)", the concrete implementation must be a case class whose constructor
3143
* arguments are all Expressions types.
@@ -335,15 +347,41 @@ abstract class BinaryExpression extends Expression with trees.BinaryNode[Express
335347

336348

337349
/**
338-
* An expression that has two inputs that are expected to the be same type. If the two inputs have
339-
* different types, the analyzer will find the tightest common type and do the proper type casting.
350+
* A [[BinaryExpression]] that is an operator, with two properties:
351+
*
352+
* 1. The string representation is "x symbol y", rather than "funcName(x, y)".
353+
* 2. Two inputs are expected to be the same type. If the two inputs have different types,
354+
* the analyzer will find the tightest common type and do the proper type casting.
340355
*/
341-
abstract class BinaryOperator extends BinaryExpression {
356+
abstract class BinaryOperator extends BinaryExpression with ExpectsInputTypes {
342357
self: Product =>
343358

359+
/**
360+
* Expected input type from both left/right child expressions, similar to the
361+
* [[ImplicitCastInputTypes]] trait.
362+
*/
363+
def inputType: AbstractDataType
364+
344365
def symbol: String
345366

346367
override def toString: String = s"($left $symbol $right)"
368+
369+
override def inputTypes: Seq[AbstractDataType] = Seq(inputType, inputType)
370+
371+
override def checkInputDataTypes(): TypeCheckResult = {
372+
// First call the checker for ExpectsInputTypes, and then check whether left and right have
373+
// the same type.
374+
super.checkInputDataTypes() match {
375+
case TypeCheckResult.TypeCheckSuccess =>
376+
if (left.dataType != right.dataType) {
377+
TypeCheckResult.TypeCheckFailure(s"differing types in '$prettyString' " +
378+
s"(${left.dataType.simpleString} and ${right.dataType.simpleString}).")
379+
} else {
380+
TypeCheckResult.TypeCheckSuccess
381+
}
382+
case TypeCheckResult.TypeCheckFailure(msg) => TypeCheckResult.TypeCheckFailure(msg)
383+
}
384+
}
347385
}
348386

349387

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ case class ScalaUDF(
2929
function: AnyRef,
3030
dataType: DataType,
3131
children: Seq[Expression],
32-
inputTypes: Seq[DataType] = Nil) extends Expression with ExpectsInputTypes {
32+
inputTypes: Seq[DataType] = Nil) extends Expression with ImplicitCastInputTypes {
3333

3434
override def nullable: Boolean = true
3535

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala

Lines changed: 37 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -18,23 +18,19 @@
1818
package org.apache.spark.sql.catalyst.expressions
1919

2020
import org.apache.spark.sql.catalyst.InternalRow
21-
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
2221
import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode}
2322
import org.apache.spark.sql.catalyst.util.TypeUtils
2423
import org.apache.spark.sql.types._
2524

26-
abstract class UnaryArithmetic extends UnaryExpression {
27-
self: Product =>
25+
26+
case class UnaryMinus(child: Expression) extends UnaryExpression with ExpectsInputTypes {
27+
28+
override def inputTypes: Seq[AbstractDataType] = Seq(NumericType)
2829

2930
override def dataType: DataType = child.dataType
30-
}
3131

32-
case class UnaryMinus(child: Expression) extends UnaryArithmetic {
3332
override def toString: String = s"-$child"
3433

35-
override def checkInputDataTypes(): TypeCheckResult =
36-
TypeUtils.checkForNumericExpr(child.dataType, "operator -")
37-
3834
private lazy val numeric = TypeUtils.getNumeric(dataType)
3935

4036
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = dataType match {
@@ -45,9 +41,13 @@ case class UnaryMinus(child: Expression) extends UnaryArithmetic {
4541
protected override def nullSafeEval(input: Any): Any = numeric.negate(input)
4642
}
4743

48-
case class UnaryPositive(child: Expression) extends UnaryArithmetic {
44+
case class UnaryPositive(child: Expression) extends UnaryExpression with ExpectsInputTypes {
4945
override def prettyName: String = "positive"
5046

47+
override def inputTypes: Seq[AbstractDataType] = Seq(NumericType)
48+
49+
override def dataType: DataType = child.dataType
50+
5151
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String =
5252
defineCodeGen(ctx, ev, c => c)
5353

@@ -57,9 +57,11 @@ case class UnaryPositive(child: Expression) extends UnaryArithmetic {
5757
/**
5858
* A function that get the absolute value of the numeric value.
5959
*/
60-
case class Abs(child: Expression) extends UnaryArithmetic {
61-
override def checkInputDataTypes(): TypeCheckResult =
62-
TypeUtils.checkForNumericExpr(child.dataType, "function abs")
60+
case class Abs(child: Expression) extends UnaryExpression with ExpectsInputTypes {
61+
62+
override def inputTypes: Seq[AbstractDataType] = Seq(NumericType)
63+
64+
override def dataType: DataType = child.dataType
6365

6466
private lazy val numeric = TypeUtils.getNumeric(dataType)
6567

@@ -71,18 +73,6 @@ abstract class BinaryArithmetic extends BinaryOperator {
7173

7274
override def dataType: DataType = left.dataType
7375

74-
override def checkInputDataTypes(): TypeCheckResult = {
75-
if (left.dataType != right.dataType) {
76-
TypeCheckResult.TypeCheckFailure(
77-
s"differing types in ${this.getClass.getSimpleName} " +
78-
s"(${left.dataType} and ${right.dataType}).")
79-
} else {
80-
checkTypesInternal(dataType)
81-
}
82-
}
83-
84-
protected def checkTypesInternal(t: DataType): TypeCheckResult
85-
8676
/** Name of the function for this expression on a [[Decimal]] type. */
8777
def decimalMethod: String =
8878
sys.error("BinaryArithmetics must override either decimalMethod or genCode")
@@ -104,62 +94,61 @@ private[sql] object BinaryArithmetic {
10494
}
10595

10696
case class Add(left: Expression, right: Expression) extends BinaryArithmetic {
97+
98+
override def inputType: AbstractDataType = NumericType
99+
107100
override def symbol: String = "+"
108101
override def decimalMethod: String = "$plus"
109102

110103
override lazy val resolved =
111104
childrenResolved && checkInputDataTypes().isSuccess && !DecimalType.isFixed(dataType)
112105

113-
protected def checkTypesInternal(t: DataType) =
114-
TypeUtils.checkForNumericExpr(t, "operator " + symbol)
115-
116106
private lazy val numeric = TypeUtils.getNumeric(dataType)
117107

118108
protected override def nullSafeEval(input1: Any, input2: Any): Any = numeric.plus(input1, input2)
119109
}
120110

121111
case class Subtract(left: Expression, right: Expression) extends BinaryArithmetic {
112+
113+
override def inputType: AbstractDataType = NumericType
114+
122115
override def symbol: String = "-"
123116
override def decimalMethod: String = "$minus"
124117

125118
override lazy val resolved =
126119
childrenResolved && checkInputDataTypes().isSuccess && !DecimalType.isFixed(dataType)
127120

128-
protected def checkTypesInternal(t: DataType) =
129-
TypeUtils.checkForNumericExpr(t, "operator " + symbol)
130-
131121
private lazy val numeric = TypeUtils.getNumeric(dataType)
132122

133123
protected override def nullSafeEval(input1: Any, input2: Any): Any = numeric.minus(input1, input2)
134124
}
135125

136126
case class Multiply(left: Expression, right: Expression) extends BinaryArithmetic {
127+
128+
override def inputType: AbstractDataType = NumericType
129+
137130
override def symbol: String = "*"
138131
override def decimalMethod: String = "$times"
139132

140133
override lazy val resolved =
141134
childrenResolved && checkInputDataTypes().isSuccess && !DecimalType.isFixed(dataType)
142135

143-
protected def checkTypesInternal(t: DataType) =
144-
TypeUtils.checkForNumericExpr(t, "operator " + symbol)
145-
146136
private lazy val numeric = TypeUtils.getNumeric(dataType)
147137

148138
protected override def nullSafeEval(input1: Any, input2: Any): Any = numeric.times(input1, input2)
149139
}
150140

151141
case class Divide(left: Expression, right: Expression) extends BinaryArithmetic {
142+
143+
override def inputType: AbstractDataType = NumericType
144+
152145
override def symbol: String = "/"
153146
override def decimalMethod: String = "$div"
154-
155147
override def nullable: Boolean = true
156148

157149
override lazy val resolved =
158150
childrenResolved && checkInputDataTypes().isSuccess && !DecimalType.isFixed(dataType)
159151

160-
protected def checkTypesInternal(t: DataType) =
161-
TypeUtils.checkForNumericExpr(t, "operator " + symbol)
162-
163152
private lazy val div: (Any, Any) => Any = dataType match {
164153
case ft: FractionalType => ft.fractional.asInstanceOf[Fractional[Any]].div
165154
case it: IntegralType => it.integral.asInstanceOf[Integral[Any]].quot
@@ -215,17 +204,16 @@ case class Divide(left: Expression, right: Expression) extends BinaryArithmetic
215204
}
216205

217206
case class Remainder(left: Expression, right: Expression) extends BinaryArithmetic {
207+
208+
override def inputType: AbstractDataType = NumericType
209+
218210
override def symbol: String = "%"
219211
override def decimalMethod: String = "remainder"
220-
221212
override def nullable: Boolean = true
222213

223214
override lazy val resolved =
224215
childrenResolved && checkInputDataTypes().isSuccess && !DecimalType.isFixed(dataType)
225216

226-
protected def checkTypesInternal(t: DataType) =
227-
TypeUtils.checkForNumericExpr(t, "operator " + symbol)
228-
229217
private lazy val integral = dataType match {
230218
case i: IntegralType => i.integral.asInstanceOf[Integral[Any]]
231219
case i: FractionalType => i.asIntegral.asInstanceOf[Integral[Any]]
@@ -281,10 +269,11 @@ case class Remainder(left: Expression, right: Expression) extends BinaryArithmet
281269
}
282270

283271
case class MaxOf(left: Expression, right: Expression) extends BinaryArithmetic {
284-
override def nullable: Boolean = left.nullable && right.nullable
272+
// TODO: Remove MaxOf and MinOf, and replace its usage with Greatest and Least.
285273

286-
protected def checkTypesInternal(t: DataType) =
287-
TypeUtils.checkForOrderingExpr(t, "function maxOf")
274+
override def inputType: AbstractDataType = TypeCollection.Ordered
275+
276+
override def nullable: Boolean = left.nullable && right.nullable
288277

289278
private lazy val ordering = TypeUtils.getOrdering(dataType)
290279

@@ -335,10 +324,11 @@ case class MaxOf(left: Expression, right: Expression) extends BinaryArithmetic {
335324
}
336325

337326
case class MinOf(left: Expression, right: Expression) extends BinaryArithmetic {
338-
override def nullable: Boolean = left.nullable && right.nullable
327+
// TODO: Remove MaxOf and MinOf, and replace its usage with Greatest and Least.
339328

340-
protected def checkTypesInternal(t: DataType) =
341-
TypeUtils.checkForOrderingExpr(t, "function minOf")
329+
override def inputType: AbstractDataType = TypeCollection.Ordered
330+
331+
override def nullable: Boolean = left.nullable && right.nullable
342332

343333
private lazy val ordering = TypeUtils.getOrdering(dataType)
344334

0 commit comments

Comments
 (0)