Skip to content

Commit 654d46a

Browse files
committed
improve tests
1 parent e0a3628 commit 654d46a

File tree

3 files changed

+67
-63
lines changed

3 files changed

+67
-63
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ abstract class BinaryArithmetic extends BinaryExpression {
9191
override def checkInputDataTypes(): TypeCheckResult = {
9292
if (left.dataType != right.dataType) {
9393
TypeCheckResult.fail(
94-
s"differing types in BinaryArithmetic, ${left.dataType} != ${right.dataType}")
94+
s"differing types in ${this.getClass.getSimpleName}, ${left.dataType} != ${right.dataType}")
9595
} else {
9696
checkTypesInternal(dataType)
9797
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -175,7 +175,7 @@ abstract class BinaryComparison extends BinaryExpression with Predicate {
175175
override def checkInputDataTypes(): TypeCheckResult = {
176176
if (left.dataType != right.dataType) {
177177
TypeCheckResult.fail(
178-
s"differing types in BinaryComparison, ${left.dataType} != ${right.dataType}")
178+
s"differing types in ${this.getClass.getSimpleName}, ${left.dataType} != ${right.dataType}")
179179
} else {
180180
TypeUtils.checkForOrderingExpr(left.dataType, "operator " + symbol)
181181
}

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionTypeCheckingSuite.scala

Lines changed: 65 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -28,89 +28,93 @@ import org.scalatest.FunSuite
2828

2929
class ExpressionTypeCheckingSuite extends FunSuite {
3030

31-
val testRelation = LocalRelation('a.int, 'b.string, 'c.boolean, 'd.array(StringType))
31+
val testRelation = LocalRelation(
32+
'intField.int,
33+
'stringField.string,
34+
'booleanField.boolean,
35+
'complexField.array(StringType))
3236

33-
def checkError(expr: Expression, errorMessage: String): Unit = {
37+
def assertError(expr: Expression, errorMessage: String): Unit = {
3438
val e = intercept[AnalysisException] {
35-
checkAnalysis(expr)
39+
assertSuccess(expr)
3640
}
3741
assert(e.getMessage.contains(
3842
s"cannot resolve '${expr.prettyString}' due to data type mismatch:"))
3943
assert(e.getMessage.contains(errorMessage))
4044
}
4145

42-
def checkAnalysis(expr: Expression): Unit = {
43-
val analyzed = testRelation.select(expr.as("_c")).analyze
46+
def assertSuccess(expr: Expression): Unit = {
47+
val analyzed = testRelation.select(expr.as("c")).analyze
4448
SimpleAnalyzer.checkAnalysis(analyzed)
4549
}
4650

4751
test("check types for unary arithmetic") {
48-
checkError(UnaryMinus('b), "operator - accepts numeric type")
49-
checkAnalysis(Sqrt('b)) // We will cast String to Double for sqrt
50-
checkError(Sqrt('c), "function sqrt accepts numeric type")
51-
checkError(Abs('b), "function abs accepts numeric type")
52-
checkError(BitwiseNot('b), "operator ~ accepts integral type")
52+
assertError(UnaryMinus('stringField), "operator - accepts numeric type")
53+
assertSuccess(Sqrt('stringField)) // We will cast String to Double for sqrt
54+
assertError(Sqrt('booleanField), "function sqrt accepts numeric type")
55+
assertError(Abs('stringField), "function abs accepts numeric type")
56+
assertError(BitwiseNot('stringField), "operator ~ accepts integral type")
5357
}
5458

5559
test("check types for binary arithmetic") {
5660
// We will cast String to Double for binary arithmetic
57-
checkAnalysis(Add('a, 'b))
58-
checkAnalysis(Subtract('a, 'b))
59-
checkAnalysis(Multiply('a, 'b))
60-
checkAnalysis(Divide('a, 'b))
61-
checkAnalysis(Remainder('a, 'b))
62-
// checkAnalysis(BitwiseAnd('a, 'b))
63-
64-
val msg = "differing types in BinaryArithmetic, IntegerType != BooleanType"
65-
checkError(Add('a, 'c), msg)
66-
checkError(Subtract('a, 'c), msg)
67-
checkError(Multiply('a, 'c), msg)
68-
checkError(Divide('a, 'c), msg)
69-
checkError(Remainder('a, 'c), msg)
70-
checkError(BitwiseAnd('a, 'c), msg)
71-
checkError(BitwiseOr('a, 'c), msg)
72-
checkError(BitwiseXor('a, 'c), msg)
73-
checkError(MaxOf('a, 'c), msg)
74-
checkError(MinOf('a, 'c), msg)
75-
76-
checkError(Add('c, 'c), "operator + accepts numeric type")
77-
checkError(Subtract('c, 'c), "operator - accepts numeric type")
78-
checkError(Multiply('c, 'c), "operator * accepts numeric type")
79-
checkError(Divide('c, 'c), "operator / accepts numeric type")
80-
checkError(Remainder('c, 'c), "operator % accepts numeric type")
81-
82-
checkError(BitwiseAnd('c, 'c), "operator & accepts integral type")
83-
checkError(BitwiseOr('c, 'c), "operator | accepts integral type")
84-
checkError(BitwiseXor('c, 'c), "operator ^ accepts integral type")
85-
86-
checkError(MaxOf('d, 'd), "function maxOf accepts non-complex type")
87-
checkError(MinOf('d, 'd), "function minOf accepts non-complex type")
61+
assertSuccess(Add('intField, 'stringField))
62+
assertSuccess(Subtract('intField, 'stringField))
63+
assertSuccess(Multiply('intField, 'stringField))
64+
assertSuccess(Divide('intField, 'stringField))
65+
assertSuccess(Remainder('intField, 'stringField))
66+
// checkAnalysis(BitwiseAnd('intField, 'stringField))
67+
68+
def msg(caller: String) = s"differing types in $caller, IntegerType != BooleanType"
69+
assertError(Add('intField, 'booleanField), msg("Add"))
70+
assertError(Subtract('intField, 'booleanField), msg("Subtract"))
71+
assertError(Multiply('intField, 'booleanField), msg("Multiply"))
72+
assertError(Divide('intField, 'booleanField), msg("Divide"))
73+
assertError(Remainder('intField, 'booleanField), msg("Remainder"))
74+
assertError(BitwiseAnd('intField, 'booleanField), msg("BitwiseAnd"))
75+
assertError(BitwiseOr('intField, 'booleanField), msg("BitwiseOr"))
76+
assertError(BitwiseXor('intField, 'booleanField), msg("BitwiseXor"))
77+
assertError(MaxOf('intField, 'booleanField), msg("MaxOf"))
78+
assertError(MinOf('intField, 'booleanField), msg("MinOf"))
79+
80+
assertError(Add('booleanField, 'booleanField), "operator + accepts numeric type")
81+
assertError(Subtract('booleanField, 'booleanField), "operator - accepts numeric type")
82+
assertError(Multiply('booleanField, 'booleanField), "operator * accepts numeric type")
83+
assertError(Divide('booleanField, 'booleanField), "operator / accepts numeric type")
84+
assertError(Remainder('booleanField, 'booleanField), "operator % accepts numeric type")
85+
86+
assertError(BitwiseAnd('booleanField, 'booleanField), "operator & accepts integral type")
87+
assertError(BitwiseOr('booleanField, 'booleanField), "operator | accepts integral type")
88+
assertError(BitwiseXor('booleanField, 'booleanField), "operator ^ accepts integral type")
89+
90+
assertError(MaxOf('complexField, 'complexField), "function maxOf accepts non-complex type")
91+
assertError(MinOf('complexField, 'complexField), "function minOf accepts non-complex type")
8892
}
8993

9094
test("check types for predicates") {
9195
// EqualTo don't have type constraint
92-
checkAnalysis(EqualTo('a, 'c))
93-
checkAnalysis(EqualNullSafe('a, 'c))
96+
assertSuccess(EqualTo('intField, 'booleanField))
97+
assertSuccess(EqualNullSafe('intField, 'booleanField))
9498

9599
// We will cast String to Double for binary comparison
96-
checkAnalysis(LessThan('a, 'b))
97-
checkAnalysis(LessThanOrEqual('a, 'b))
98-
checkAnalysis(GreaterThan('a, 'b))
99-
checkAnalysis(GreaterThanOrEqual('a, 'b))
100-
101-
val msg = "differing types in BinaryComparison, IntegerType != BooleanType"
102-
checkError(LessThan('a, 'c), msg)
103-
checkError(LessThanOrEqual('a, 'c), msg)
104-
checkError(GreaterThan('a, 'c), msg)
105-
checkError(GreaterThanOrEqual('a, 'c), msg)
106-
107-
checkError(LessThan('d, 'd), "operator < accepts non-complex type")
108-
checkError(LessThanOrEqual('d, 'd), "operator <= accepts non-complex type")
109-
checkError(GreaterThan('d, 'd), "operator > accepts non-complex type")
110-
checkError(GreaterThanOrEqual('d, 'd), "operator >= accepts non-complex type")
111-
112-
checkError(If('a, 'a, 'a), "type of predicate expression in If should be boolean")
113-
checkError(If('c, 'a, 'b), "differing types in If, IntegerType != StringType")
100+
assertSuccess(LessThan('intField, 'stringField))
101+
assertSuccess(LessThanOrEqual('intField, 'stringField))
102+
assertSuccess(GreaterThan('intField, 'stringField))
103+
assertSuccess(GreaterThanOrEqual('intField, 'stringField))
104+
105+
def msg(caller: String) = s"differing types in $caller, IntegerType != BooleanType"
106+
assertError(LessThan('intField, 'booleanField), msg("LessThan"))
107+
assertError(LessThanOrEqual('intField, 'booleanField), msg("LessThanOrEqual"))
108+
assertError(GreaterThan('intField, 'booleanField), msg("GreaterThan"))
109+
assertError(GreaterThanOrEqual('intField, 'booleanField), msg("GreaterThanOrEqual"))
110+
111+
assertError(LessThan('complexField, 'complexField), "operator < accepts non-complex type")
112+
assertError(LessThanOrEqual('complexField, 'complexField), "operator <= accepts non-complex type")
113+
assertError(GreaterThan('complexField, 'complexField), "operator > accepts non-complex type")
114+
assertError(GreaterThanOrEqual('complexField, 'complexField), "operator >= accepts non-complex type")
115+
116+
assertError(If('intField, 'stringField, 'stringField), "type of predicate expression in If should be boolean")
117+
assertError(If('booleanField, 'intField, 'stringField), "differing types in If, IntegerType != StringType")
114118

115119
// Will write tests for CaseWhen later,
116120
// as the error reporting of it is not handle by the new interface for now

0 commit comments

Comments
 (0)