Commit d38cf21

cloud-fan authored and rxin committed
[SPARK-7562][SPARK-6444][SQL] Improve error reporting for expression data type mismatch
It seems hard to find a common pattern for checking types in `Expression`. Sometimes we know exactly what input types we need (like `And`, which needs two booleans); sometimes we only have a rule (like `Add`, which needs two numeric types that are equal). So I defined a general interface `checkInputDataTypes` in `Expression` which returns a `TypeCheckResult`. `TypeCheckResult` can tell whether the expression passes the type check and, if not, what the type mismatch is. This PR mainly applies input type checking to arithmetic and predicate expressions.

TODO: apply the type checking interface to more expressions.

Author: Wenchen Fan <[email protected]>

Closes apache#6405 from cloud-fan/6444 and squashes the following commits:

b5ff31b [Wenchen Fan] address comments
b917275 [Wenchen Fan] rebase
39929d9 [Wenchen Fan] add todo
0808fd2 [Wenchen Fan] make constrcutor of TypeCheckResult private
3bee157 [Wenchen Fan] and decimal type coercion rule for binary comparison
8883025 [Wenchen Fan] apply type check interface to CaseWhen
cffb67c [Wenchen Fan] to have resolved call the data type check function
6eaadff [Wenchen Fan] add equal type constraint to EqualTo
3affbd8 [Wenchen Fan] more fixes
654d46a [Wenchen Fan] improve tests
e0a3628 [Wenchen Fan] improve error message
1524ff6 [Wenchen Fan] fix style
69ca3fe [Wenchen Fan] add error message and tests
c71d02c [Wenchen Fan] fix hive tests
6491721 [Wenchen Fan] use value class TypeCheckResult
7ae76b9 [Wenchen Fan] address comments
cb77e4f [Wenchen Fan] Improve error reporting for expression data type mismatch
1 parent ce320cb commit d38cf21
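
For illustration only (this sketch is not part of the commit's diff; the class name and failure messages are hypothetical): a binary arithmetic expression could implement the new hook roughly as follows, failing when its operands have not been coerced to a matching numeric type.

import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.types.NumericType

// Hypothetical sketch of how an Add-like expression might use the new interface.
abstract class BinaryArithmeticSketch extends BinaryExpression {
  self: Product =>

  override def checkInputDataTypes(): TypeCheckResult = {
    if (left.dataType != right.dataType) {
      // Type coercion should already have aligned both sides; if not, report it.
      TypeCheckResult.TypeCheckFailure(
        s"differing types in ${this.getClass.getSimpleName}, " +
          s"${left.dataType} != ${right.dataType}")
    } else if (!left.dataType.isInstanceOf[NumericType]) {
      TypeCheckResult.TypeCheckFailure(
        s"operator $symbol accepts numeric types, not ${left.dataType}")
    } else {
      TypeCheckResult.TypeCheckSuccess
    }
  }
}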

File tree

17 files changed: +583 -421 lines


core/src/test/scala/org/apache/spark/SparkFunSuite.scala

Lines changed: 2 additions & 2 deletions
@@ -30,8 +30,8 @@ private[spark] abstract class SparkFunSuite extends FunSuite with Logging {
    * Log the suite name and the test name before and after each test.
    *
    * Subclasses should never override this method. If they wish to run
-   * custom code before and after each test, they should should mix in
-   * the {{org.scalatest.BeforeAndAfter}} trait instead.
+   * custom code before and after each test, they should mix in the
+   * {{org.scalatest.BeforeAndAfter}} trait instead.
    */
   final protected override def withFixture(test: NoArgTest): Outcome = {
     val testName = test.text

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala

Lines changed: 7 additions & 5 deletions
@@ -62,15 +62,17 @@ trait CheckAnalysis {
            val from = operator.inputSet.map(_.name).mkString(", ")
            a.failAnalysis(s"cannot resolve '${a.prettyString}' given input columns $from")

+          case e: Expression if e.checkInputDataTypes().isFailure =>
+            e.checkInputDataTypes() match {
+              case TypeCheckResult.TypeCheckFailure(message) =>
+                e.failAnalysis(
+                  s"cannot resolve '${e.prettyString}' due to data type mismatch: $message")
+            }
+
          case c: Cast if !c.resolved =>
            failAnalysis(
              s"invalid cast from ${c.child.dataType.simpleString} to ${c.dataType.simpleString}")

-          case b: BinaryExpression if !b.resolved =>
-            failAnalysis(
-              s"invalid expression ${b.prettyString} " +
-                s"between ${b.left.dataType.simpleString} and ${b.right.dataType.simpleString}")
-
          case WindowExpression(UnresolvedWindowFunction(name, _), _) =>
            failAnalysis(
              s"Could not resolve window function '$name'. " +
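
For illustration (not part of the diff): the analyzer now assembles its error from the expression's own `TypeCheckResult`, so the wording after the colon is whatever that expression's `checkInputDataTypes()` returned. A minimal sketch of the assembly, with a made-up expression string and placeholder failure text:

import org.apache.spark.sql.catalyst.analysis.TypeCheckResult

// Mirrors the new CheckAnalysis branch above, outside the analyzer.
def mismatchMessage(prettyString: String, result: TypeCheckResult): Option[String] = result match {
  case TypeCheckResult.TypeCheckFailure(message) =>
    Some(s"cannot resolve '$prettyString' due to data type mismatch: $message")
  case TypeCheckResult.TypeCheckSuccess => None
}

// mismatchMessage("(a && 1)", TypeCheckResult.TypeCheckFailure("..."))
//   => Some("cannot resolve '(a && 1)' due to data type mismatch: ...")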

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala

Lines changed: 61 additions & 71 deletions
@@ -41,7 +41,7 @@ object HiveTypeCoercion {
    * with primitive types, because in that case the precision and scale of the result depends on
    * the operation. Those rules are implemented in [[HiveTypeCoercion.DecimalPrecision]].
    */
-  val findTightestCommonType: (DataType, DataType) => Option[DataType] = {
+  val findTightestCommonTypeOfTwo: (DataType, DataType) => Option[DataType] = {
     case (t1, t2) if t1 == t2 => Some(t1)
     case (NullType, t1) => Some(t1)
     case (t1, NullType) => Some(t1)
@@ -57,6 +57,17 @@ object HiveTypeCoercion {

     case _ => None
   }
+
+  /**
+   * Find the tightest common type of a set of types by continuously applying
+   * `findTightestCommonTypeOfTwo` on these types.
+   */
+  private def findTightestCommonType(types: Seq[DataType]) = {
+    types.foldLeft[Option[DataType]](Some(NullType))((r, c) => r match {
+      case None => None
+      case Some(d) => findTightestCommonTypeOfTwo(d, c)
+    })
+  }
 }

 /**
@@ -180,7 +191,7 @@ trait HiveTypeCoercion {

       case (l, r) if l.dataType != r.dataType =>
         logDebug(s"Resolving mismatched union input ${l.dataType}, ${r.dataType}")
-        findTightestCommonType(l.dataType, r.dataType).map { widestType =>
+        findTightestCommonTypeOfTwo(l.dataType, r.dataType).map { widestType =>
           val newLeft =
             if (l.dataType == widestType) l else Alias(Cast(l, widestType), l.name)()
           val newRight =
@@ -217,7 +228,7 @@ trait HiveTypeCoercion {
       case e if !e.childrenResolved => e

       case b: BinaryExpression if b.left.dataType != b.right.dataType =>
-        findTightestCommonType(b.left.dataType, b.right.dataType).map { widestType =>
+        findTightestCommonTypeOfTwo(b.left.dataType, b.right.dataType).map { widestType =>
           val newLeft =
             if (b.left.dataType == widestType) b.left else Cast(b.left, widestType)
           val newRight =
@@ -441,21 +452,18 @@ trait HiveTypeCoercion {
           DecimalType(min(p1 - s1, p2 - s2) + max(s1, s2), max(s1, s2))
         )

-      case LessThan(e1 @ DecimalType.Expression(p1, s1),
-                    e2 @ DecimalType.Expression(p2, s2)) if p1 != p2 || s1 != s2 =>
-        LessThan(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited))
-
-      case LessThanOrEqual(e1 @ DecimalType.Expression(p1, s1),
-                           e2 @ DecimalType.Expression(p2, s2)) if p1 != p2 || s1 != s2 =>
-        LessThanOrEqual(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited))
-
-      case GreaterThan(e1 @ DecimalType.Expression(p1, s1),
-                       e2 @ DecimalType.Expression(p2, s2)) if p1 != p2 || s1 != s2 =>
-        GreaterThan(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited))
-
-      case GreaterThanOrEqual(e1 @ DecimalType.Expression(p1, s1),
-                              e2 @ DecimalType.Expression(p2, s2)) if p1 != p2 || s1 != s2 =>
-        GreaterThanOrEqual(Cast(e1, DecimalType.Unlimited), Cast(e2, DecimalType.Unlimited))
+      // When we compare 2 decimal types with different precisions, cast them to the smallest
+      // common precision.
+      case b @ BinaryComparison(e1 @ DecimalType.Expression(p1, s1),
+                                e2 @ DecimalType.Expression(p2, s2)) if p1 != p2 || s1 != s2 =>
+        val resultType = DecimalType(max(p1, p2), max(s1, s2))
+        b.makeCopy(Array(Cast(e1, resultType), Cast(e2, resultType)))
+      case b @ BinaryComparison(e1 @ DecimalType.Fixed(_, _), e2)
+          if e2.dataType == DecimalType.Unlimited =>
+        b.makeCopy(Array(Cast(e1, DecimalType.Unlimited), e2))
+      case b @ BinaryComparison(e1, e2 @ DecimalType.Fixed(_, _))
+          if e1.dataType == DecimalType.Unlimited =>
+        b.makeCopy(Array(e1, Cast(e2, DecimalType.Unlimited)))

       // Promote integers inside a binary expression with fixed-precision decimals to decimals,
       // and fixed-precision decimals in an expression with floats / doubles to doubles
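
A worked example of the consolidated rule (attribute names and precisions are made up for illustration): comparing Decimal(5, 2) with Decimal(10, 1) now casts both sides to DecimalType(max(5, 10), max(2, 1)) = DecimalType(10, 2), and it does so for any comparison operator rather than only the four that were previously special-cased.

import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types._

// Hypothetical attributes, for illustration only.
val a = AttributeReference("a", DecimalType(5, 2))()
val b = AttributeReference("b", DecimalType(10, 1))()

// GreaterThan(a, b) is rewritten by the rule above into the equivalent of:
val coerced = GreaterThan(Cast(a, DecimalType(10, 2)), Cast(b, DecimalType(10, 2)))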
@@ -570,7 +578,7 @@ trait HiveTypeCoercion {

       case a @ CreateArray(children) if !a.resolved =>
         val commonType = a.childTypes.reduce(
-          (a, b) => findTightestCommonType(a, b).getOrElse(StringType))
+          (a, b) => findTightestCommonTypeOfTwo(a, b).getOrElse(StringType))
         CreateArray(
           children.map(c => if (c.dataType == commonType) c else Cast(c, commonType)))

@@ -599,14 +607,9 @@ trait HiveTypeCoercion {
       // from the list. So we need to make sure the return type is deterministic and
       // compatible with every child column.
       case Coalesce(es) if es.map(_.dataType).distinct.size > 1 =>
-        val dt: Option[DataType] = Some(NullType)
         val types = es.map(_.dataType)
-        val rt = types.foldLeft(dt)((r, c) => r match {
-          case None => None
-          case Some(d) => findTightestCommonType(d, c)
-        })
-        rt match {
-          case Some(finaldt) => Coalesce(es.map(Cast(_, finaldt)))
+        findTightestCommonType(types) match {
+          case Some(finalDataType) => Coalesce(es.map(Cast(_, finalDataType)))
           case None =>
             sys.error(s"Could not determine return type of Coalesce for ${types.mkString(",")}")
         }
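
To make the fold concrete, here is a standalone sketch of what the new `findTightestCommonType` helper computes, assuming the usual numeric-precedence behaviour of `findTightestCommonTypeOfTwo` (which this commit only renames):

import org.apache.spark.sql.types._

// Same shape as the private helper added above; the two-way function is passed in
// so the sketch is self-contained.
def tightest(
    types: Seq[DataType],
    twoWay: (DataType, DataType) => Option[DataType]): Option[DataType] = {
  types.foldLeft[Option[DataType]](Some(NullType))((r, c) => r match {
    case None => None             // once two types are incompatible, give up
    case Some(d) => twoWay(d, c)  // otherwise keep widening pairwise
  })
}

// tightest(Seq(IntegerType, LongType), HiveTypeCoercion.findTightestCommonTypeOfTwo)
//   is expected to be Some(LongType), while mixing in StringType is expected to yield
//   None, which is why Coalesce above still falls back to sys.error in that case.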
@@ -619,17 +622,13 @@ trait HiveTypeCoercion {
    */
   object Division extends Rule[LogicalPlan] {
     def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
-      // Skip nodes who's children have not been resolved yet.
-      case e if !e.childrenResolved => e
+      // Skip nodes who has not been resolved yet,
+      // as this is an extra rule which should be applied at last.
+      case e if !e.resolved => e

       // Decimal and Double remain the same
-      case d: Divide if d.resolved && d.dataType == DoubleType => d
-      case d: Divide if d.resolved && d.dataType.isInstanceOf[DecimalType] => d
-
-      case Divide(l, r) if l.dataType.isInstanceOf[DecimalType] =>
-        Divide(l, Cast(r, DecimalType.Unlimited))
-      case Divide(l, r) if r.dataType.isInstanceOf[DecimalType] =>
-        Divide(Cast(l, DecimalType.Unlimited), r)
+      case d: Divide if d.dataType == DoubleType => d
+      case d: Divide if d.dataType.isInstanceOf[DecimalType] => d

       case Divide(l, r) => Divide(Cast(l, DoubleType), Cast(r, DoubleType))
     }
@@ -642,42 +641,33 @@ trait HiveTypeCoercion {
     import HiveTypeCoercion._

     def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
-      case cw: CaseWhenLike if cw.childrenResolved && !cw.valueTypesEqual =>
-        logDebug(s"Input values for null casting ${cw.valueTypes.mkString(",")}")
-        val commonType = cw.valueTypes.reduce { (v1, v2) =>
-          findTightestCommonType(v1, v2).getOrElse(sys.error(
-            s"Types in CASE WHEN must be the same or coercible to a common type: $v1 != $v2"))
-        }
-        val transformedBranches = cw.branches.sliding(2, 2).map {
-          case Seq(when, value) if value.dataType != commonType =>
-            Seq(when, Cast(value, commonType))
-          case Seq(elseVal) if elseVal.dataType != commonType =>
-            Seq(Cast(elseVal, commonType))
-          case s => s
-        }.reduce(_ ++ _)
-        cw match {
-          case _: CaseWhen =>
-            CaseWhen(transformedBranches)
-          case CaseKeyWhen(key, _) =>
-            CaseKeyWhen(key, transformedBranches)
-        }
-
-      case ckw: CaseKeyWhen if ckw.childrenResolved && !ckw.resolved =>
-        val commonType = (ckw.key +: ckw.whenList).map(_.dataType).reduce { (v1, v2) =>
-          findTightestCommonType(v1, v2).getOrElse(sys.error(
-            s"Types in CASE WHEN must be the same or coercible to a common type: $v1 != $v2"))
-        }
-        val transformedBranches = ckw.branches.sliding(2, 2).map {
-          case Seq(when, then) if when.dataType != commonType =>
-            Seq(Cast(when, commonType), then)
-          case s => s
-        }.reduce(_ ++ _)
-        val transformedKey = if (ckw.key.dataType != commonType) {
-          Cast(ckw.key, commonType)
-        } else {
-          ckw.key
-        }
-        CaseKeyWhen(transformedKey, transformedBranches)
+      case c: CaseWhenLike if c.childrenResolved && !c.valueTypesEqual =>
+        logDebug(s"Input values for null casting ${c.valueTypes.mkString(",")}")
+        val maybeCommonType = findTightestCommonType(c.valueTypes)
+        maybeCommonType.map { commonType =>
+          val castedBranches = c.branches.grouped(2).map {
+            case Seq(when, value) if value.dataType != commonType =>
+              Seq(when, Cast(value, commonType))
+            case Seq(elseVal) if elseVal.dataType != commonType =>
+              Seq(Cast(elseVal, commonType))
+            case other => other
+          }.reduce(_ ++ _)
+          c match {
+            case _: CaseWhen => CaseWhen(castedBranches)
+            case CaseKeyWhen(key, _) => CaseKeyWhen(key, castedBranches)
+          }
+        }.getOrElse(c)
+
+      case c: CaseKeyWhen if c.childrenResolved && !c.resolved =>
+        val maybeCommonType = findTightestCommonType((c.key +: c.whenList).map(_.dataType))
+        maybeCommonType.map { commonType =>
+          val castedBranches = c.branches.grouped(2).map {
+            case Seq(when, then) if when.dataType != commonType =>
+              Seq(Cast(when, commonType), then)
+            case other => other
+          }.reduce(_ ++ _)
+          CaseKeyWhen(Cast(c.key, commonType), castedBranches)
+        }.getOrElse(c)
     }
   }

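
A worked example of the rewritten rule (literals chosen for illustration): a CASE WHEN whose branch values are IntegerType and DoubleType gets its integer value cast to the tightest common type, DoubleType. If no common type exists, the expression is now returned unchanged via getOrElse, so the mismatch can be reported through the new type-check path in CheckAnalysis instead of an immediate sys.error.

import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types._

// Hypothetical predicate and branch values, for illustration only:
// CASE WHEN p THEN 1 ELSE 2.5 END
val p = AttributeReference("p", BooleanType)()
val branches = Seq(p, Literal(1), Literal(2.5))

// valueTypes are IntegerType and DoubleType; the tightest common type is DoubleType,
// so the rule produces the equivalent of (the double literal needs no cast):
val coerced = CaseWhen(Seq(p, Cast(Literal(1), DoubleType), Literal(2.5)))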

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCheckResult.scala

Lines changed: 45 additions & 0 deletions
@@ -0,0 +1,45 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.analysis
+
+/**
+ * Represents the result of `Expression.checkInputDataTypes`.
+ * We will throw `AnalysisException` in `CheckAnalysis` if `isFailure` is true.
+ */
+trait TypeCheckResult {
+  def isFailure: Boolean = !isSuccess
+  def isSuccess: Boolean
+}
+
+object TypeCheckResult {
+
+  /**
+   * Represents the successful result of `Expression.checkInputDataTypes`.
+   */
+  object TypeCheckSuccess extends TypeCheckResult {
+    def isSuccess: Boolean = true
+  }
+
+  /**
+   * Represents the failing result of `Expression.checkInputDataTypes`,
+   * with a error message to show the reason of failure.
+   */
+  case class TypeCheckFailure(message: String) extends TypeCheckResult {
+    def isSuccess: Boolean = false
+  }
+}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala

Lines changed: 22 additions & 6 deletions
@@ -17,7 +17,7 @@

 package org.apache.spark.sql.catalyst.expressions

-import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
+import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, UnresolvedAttribute}
 import org.apache.spark.sql.catalyst.trees
 import org.apache.spark.sql.catalyst.trees.TreeNode
 import org.apache.spark.sql.types._
@@ -53,11 +53,12 @@ abstract class Expression extends TreeNode[Expression] {

   /**
    * Returns `true` if this expression and all its children have been resolved to a specific schema
-   * and `false` if it still contains any unresolved placeholders. Implementations of expressions
-   * should override this if the resolution of this type of expression involves more than just
-   * the resolution of its children.
+   * and input data types checking passed, and `false` if it still contains any unresolved
+   * placeholders or has data types mismatch.
+   * Implementations of expressions should override this if the resolution of this type of
+   * expression involves more than just the resolution of its children and type checking.
    */
-  lazy val resolved: Boolean = childrenResolved
+  lazy val resolved: Boolean = childrenResolved && checkInputDataTypes().isSuccess

   /**
    * Returns the [[DataType]] of the result of evaluating this expression. It is
@@ -94,12 +95,21 @@ abstract class Expression extends TreeNode[Expression] {
       case (i1, i2) => i1 == i2
     }
   }
+
+  /**
+   * Checks the input data types, returns `TypeCheckResult.success` if it's valid,
+   * or returns a `TypeCheckResult` with an error message if invalid.
+   * Note: it's not valid to call this method until `childrenResolved == true`
+   * TODO: we should remove the default implementation and implement it for all
+   * expressions with proper error message.
+   */
+  def checkInputDataTypes(): TypeCheckResult = TypeCheckResult.TypeCheckSuccess
 }

 abstract class BinaryExpression extends Expression with trees.BinaryNode[Expression] {
   self: Product =>

-  def symbol: String
+  def symbol: String = sys.error(s"BinaryExpressions must override either toString or symbol")

   override def foldable: Boolean = left.foldable && right.foldable

@@ -133,7 +143,13 @@ case class GroupExpression(children: Seq[Expression]) extends Expression {
  * so that the proper type conversions can be performed in the analyzer.
  */
 trait ExpectsInputTypes {
+  self: Expression =>

   def expectedChildTypes: Seq[DataType]

+  override def checkInputDataTypes(): TypeCheckResult = {
+    // We will always do type casting for `ExpectsInputTypes` in `HiveTypeCoercion`,
+    // so type mismatch error won't be reported here, but for underling `Cast`s.
+    TypeCheckResult.TypeCheckSuccess
+  }
 }
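
A sketch of an expression relying on the trait (hypothetical, not from this diff; evaluation and nullability are left abstract): the new self-type means such a class is necessarily an Expression, and its inherited checkInputDataTypes() simply reports success because HiveTypeCoercion is expected to insert any needed Casts around the children.

import org.apache.spark.sql.types._

// Hypothetical string predicate, for illustration only.
abstract class StringPredicateSketch extends BinaryExpression with ExpectsInputTypes {
  self: Product =>

  override def expectedChildTypes: Seq[DataType] = Seq(StringType, StringType)
  override def dataType: DataType = BooleanType
}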
