Skip to content

Commit 6491721

Browse files
committed
use value class TypeCheckResult
1 parent 7ae76b9 commit 6491721

File tree

5 files changed

+91
-56
lines changed

5 files changed

+91
-56
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,10 +62,10 @@ trait CheckAnalysis {
6262
val from = operator.inputSet.map(_.name).mkString(", ")
6363
a.failAnalysis(s"cannot resolve '${a.prettyString}' given input columns $from")
6464

65-
case e: Expression if e.checkInputDataTypes.isDefined =>
65+
case e: Expression if e.checkInputDataTypes.hasError =>
6666
e.failAnalysis(
6767
s"cannot resolve '${e.prettyString}' due to data type mismatch: " +
68-
e.checkInputDataTypes.get)
68+
e.checkInputDataTypes.errorMessage)
6969

7070
case c: Cast if !c.resolved =>
7171
failAnalysis(
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.sql.catalyst.analysis
19+
20+
/**
21+
* todo
22+
*/
23+
class TypeCheckResult(val errorMessage: String) extends AnyVal {
24+
def hasError: Boolean = errorMessage != null
25+
}
26+
27+
object TypeCheckResult {
28+
val success: TypeCheckResult = new TypeCheckResult(null)
29+
def fail(msg: String): TypeCheckResult = new TypeCheckResult(msg)
30+
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818
package org.apache.spark.sql.catalyst.expressions
1919

20-
import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
20+
import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, UnresolvedAttribute}
2121
import org.apache.spark.sql.catalyst.trees
2222
import org.apache.spark.sql.catalyst.trees.TreeNode
2323
import org.apache.spark.sql.types._
@@ -90,7 +90,7 @@ abstract class Expression extends TreeNode[Expression] {
9090
/**
9191
* todo
9292
*/
93-
def checkInputDataTypes: Option[String] = None
93+
def checkInputDataTypes: TypeCheckResult = TypeCheckResult.success
9494
}
9595

9696
abstract class BinaryExpression extends Expression with trees.BinaryNode[Expression] {
@@ -134,9 +134,9 @@ trait ExpectsInputTypes {
134134

135135
def expectedChildTypes: Seq[DataType]
136136

137-
override def checkInputDataTypes: Option[String] = {
137+
override def checkInputDataTypes: TypeCheckResult = {
138138
// We will always do type casting for `ExpectsInputTypes` in `HiveTypeCoercion`,
139139
// so type mismatch error won't be reported here, but for underling `Cast`s.
140-
None
140+
TypeCheckResult.success
141141
}
142142
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala

Lines changed: 31 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
package org.apache.spark.sql.catalyst.expressions
1919

20+
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
2021
import org.apache.spark.sql.catalyst.util.TypeUtils
2122
import org.apache.spark.sql.types._
2223

@@ -27,11 +28,11 @@ abstract class UnaryArithmetic extends UnaryExpression {
2728
override def nullable: Boolean = child.nullable
2829
override def dataType: DataType = child.dataType
2930

30-
override def checkInputDataTypes: Option[String] = {
31+
override def checkInputDataTypes: TypeCheckResult = {
3132
if (TypeUtils.validForNumericExpr(child.dataType)) {
32-
None
33+
TypeCheckResult.success
3334
} else {
34-
Some("todo")
35+
TypeCheckResult.fail("todo")
3536
}
3637
}
3738

@@ -86,15 +87,16 @@ abstract class BinaryArithmetic extends BinaryExpression {
8687

8788
override def dataType: DataType = left.dataType
8889

89-
override def checkInputDataTypes: Option[String] = {
90+
override def checkInputDataTypes: TypeCheckResult = {
9091
if (left.dataType != right.dataType) {
91-
Some(s"differing types in BinaryArithmetics, ${left.dataType}, ${right.dataType}")
92+
TypeCheckResult.fail(
93+
s"differing types in BinaryArithmetics -- ${left.dataType}, ${right.dataType}")
9294
} else {
9395
checkTypesInternal(dataType)
9496
}
9597
}
9698

97-
protected def checkTypesInternal(t: DataType): Option[String]
99+
protected def checkTypesInternal(t: DataType): TypeCheckResult
98100

99101
override def eval(input: Row): Any = {
100102
val evalE1 = left.eval(input)
@@ -123,9 +125,9 @@ case class Add(left: Expression, right: Expression) extends BinaryArithmetic {
123125

124126
protected def checkTypesInternal(t: DataType) = {
125127
if (TypeUtils.validForNumericExpr(t)) {
126-
None
128+
TypeCheckResult.success
127129
} else {
128-
Some("todo")
130+
TypeCheckResult.fail("todo")
129131
}
130132
}
131133

@@ -143,9 +145,9 @@ case class Subtract(left: Expression, right: Expression) extends BinaryArithmeti
143145

144146
protected def checkTypesInternal(t: DataType) = {
145147
if (TypeUtils.validForNumericExpr(t)) {
146-
None
148+
TypeCheckResult.success
147149
} else {
148-
Some("todo")
150+
TypeCheckResult.fail("todo")
149151
}
150152
}
151153

@@ -163,9 +165,9 @@ case class Multiply(left: Expression, right: Expression) extends BinaryArithmeti
163165

164166
protected def checkTypesInternal(t: DataType) = {
165167
if (TypeUtils.validForNumericExpr(t)) {
166-
None
168+
TypeCheckResult.success
167169
} else {
168-
Some("todo")
170+
TypeCheckResult.fail("todo")
169171
}
170172
}
171173

@@ -184,9 +186,9 @@ case class Divide(left: Expression, right: Expression) extends BinaryArithmetic
184186

185187
protected def checkTypesInternal(t: DataType) = {
186188
if (TypeUtils.validForNumericExpr(t)) {
187-
None
189+
TypeCheckResult.success
188190
} else {
189-
Some("todo")
191+
TypeCheckResult.fail("todo")
190192
}
191193
}
192194

@@ -220,9 +222,9 @@ case class Remainder(left: Expression, right: Expression) extends BinaryArithmet
220222

221223
protected def checkTypesInternal(t: DataType) = {
222224
if (TypeUtils.validForNumericExpr(t)) {
223-
None
225+
TypeCheckResult.success
224226
} else {
225-
Some("todo")
227+
TypeCheckResult.fail("todo")
226228
}
227229
}
228230

@@ -254,9 +256,9 @@ case class BitwiseAnd(left: Expression, right: Expression) extends BinaryArithme
254256

255257
protected def checkTypesInternal(t: DataType) = {
256258
if (TypeUtils.validForBitwiseExpr(t)) {
257-
None
259+
TypeCheckResult.success
258260
} else {
259-
Some("todo")
261+
TypeCheckResult.fail("todo")
260262
}
261263
}
262264

@@ -282,9 +284,9 @@ case class BitwiseOr(left: Expression, right: Expression) extends BinaryArithmet
282284

283285
protected def checkTypesInternal(t: DataType) = {
284286
if (TypeUtils.validForBitwiseExpr(t)) {
285-
None
287+
TypeCheckResult.success
286288
} else {
287-
Some("todo")
289+
TypeCheckResult.fail("todo")
288290
}
289291
}
290292

@@ -310,9 +312,9 @@ case class BitwiseXor(left: Expression, right: Expression) extends BinaryArithme
310312

311313
protected def checkTypesInternal(t: DataType) = {
312314
if (TypeUtils.validForBitwiseExpr(t)) {
313-
None
315+
TypeCheckResult.success
314316
} else {
315-
Some("todo")
317+
TypeCheckResult.fail("todo")
316318
}
317319
}
318320

@@ -336,11 +338,11 @@ case class BitwiseXor(left: Expression, right: Expression) extends BinaryArithme
336338
case class BitwiseNot(child: Expression) extends UnaryArithmetic {
337339
override def toString: String = s"~$child"
338340

339-
override def checkInputDataTypes: Option[String] = {
341+
override def checkInputDataTypes: TypeCheckResult = {
340342
if (TypeUtils.validForBitwiseExpr(dataType)) {
341-
None
343+
TypeCheckResult.success
342344
} else {
343-
Some("todo")
345+
TypeCheckResult.fail("todo")
344346
}
345347
}
346348

@@ -363,9 +365,9 @@ case class MaxOf(left: Expression, right: Expression) extends BinaryArithmetic {
363365

364366
protected def checkTypesInternal(t: DataType) = {
365367
if (TypeUtils.validForOrderingExpr(t)) {
366-
None
368+
TypeCheckResult.success
367369
} else {
368-
Some("todo")
370+
TypeCheckResult.fail("todo")
369371
}
370372
}
371373

@@ -395,9 +397,9 @@ case class MinOf(left: Expression, right: Expression) extends BinaryArithmetic {
395397

396398
protected def checkTypesInternal(t: DataType) = {
397399
if (TypeUtils.validForOrderingExpr(t)) {
398-
None
400+
TypeCheckResult.success
399401
} else {
400-
Some("todo")
402+
TypeCheckResult.fail("todo")
401403
}
402404
}
403405

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala

Lines changed: 24 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
package org.apache.spark.sql.catalyst.expressions
1919

20+
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
2021
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
2122
import org.apache.spark.sql.catalyst.util.TypeUtils
2223
import org.apache.spark.sql.types.{DecimalType, BinaryType, BooleanType, DataType}
@@ -171,15 +172,16 @@ case class Or(left: Expression, right: Expression)
171172
abstract class BinaryComparison extends BinaryExpression with Predicate {
172173
self: Product =>
173174

174-
override def checkInputDataTypes: Option[String] = {
175+
override def checkInputDataTypes: TypeCheckResult = {
175176
if (left.dataType != right.dataType) {
176-
Some(s"differing types in BinaryComparisons, ${left.dataType}, ${right.dataType}")
177+
TypeCheckResult.fail(
178+
s"differing types in BinaryComparisons -- ${left.dataType}, ${right.dataType}")
177179
} else {
178180
checkTypesInternal(left.dataType)
179181
}
180182
}
181183

182-
protected def checkTypesInternal(t: DataType): Option[String] = None
184+
protected def checkTypesInternal(t: DataType): TypeCheckResult = TypeCheckResult.success
183185

184186
override def eval(input: Row): Any = {
185187
val evalE1 = left.eval(input)
@@ -231,9 +233,9 @@ case class LessThan(left: Expression, right: Expression) extends BinaryCompariso
231233

232234
override protected def checkTypesInternal(t: DataType) = {
233235
if (TypeUtils.validForOrderingExpr(t)) {
234-
None
236+
TypeCheckResult.success
235237
} else {
236-
Some("todo")
238+
TypeCheckResult.fail("todo")
237239
}
238240
}
239241

@@ -247,9 +249,9 @@ case class LessThanOrEqual(left: Expression, right: Expression) extends BinaryCo
247249

248250
override protected def checkTypesInternal(t: DataType) = {
249251
if (TypeUtils.validForOrderingExpr(t)) {
250-
None
252+
TypeCheckResult.success
251253
} else {
252-
Some("todo")
254+
TypeCheckResult.fail("todo")
253255
}
254256
}
255257

@@ -263,9 +265,9 @@ case class GreaterThan(left: Expression, right: Expression) extends BinaryCompar
263265

264266
override protected def checkTypesInternal(t: DataType) = {
265267
if (TypeUtils.validForOrderingExpr(t)) {
266-
None
268+
TypeCheckResult.success
267269
} else {
268-
Some("todo")
270+
TypeCheckResult.fail("todo")
269271
}
270272
}
271273

@@ -279,9 +281,9 @@ case class GreaterThanOrEqual(left: Expression, right: Expression) extends Binar
279281

280282
override protected def checkTypesInternal(t: DataType) = {
281283
if (TypeUtils.validForOrderingExpr(t)) {
282-
None
284+
TypeCheckResult.success
283285
} else {
284-
Some("todo")
286+
TypeCheckResult.fail("todo")
285287
}
286288
}
287289

@@ -296,11 +298,12 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi
296298
override def children: Seq[Expression] = predicate :: trueValue :: falseValue :: Nil
297299
override def nullable: Boolean = trueValue.nullable || falseValue.nullable
298300

299-
override def checkInputDataTypes: Option[String] = {
301+
override def checkInputDataTypes: TypeCheckResult = {
300302
if (trueValue.dataType != falseValue.dataType) {
301-
Some(s"differing types in If, ${trueValue.dataType}, ${falseValue.dataType}")
303+
TypeCheckResult.fail(
304+
s"differing types in If, ${trueValue.dataType}, ${falseValue.dataType}")
302305
} else {
303-
None
306+
TypeCheckResult.success
304307
}
305308
}
306309

@@ -357,13 +360,13 @@ case class CaseWhen(branches: Seq[Expression]) extends CaseWhenLike {
357360

358361
override def children: Seq[Expression] = branches
359362

360-
override def checkInputDataTypes: Option[String] = {
363+
override def checkInputDataTypes: TypeCheckResult = {
361364
if (!whenList.forall(_.dataType == BooleanType)) {
362-
Some(s"WHEN expressions should all be boolean type")
365+
TypeCheckResult.fail(s"WHEN expressions should all be boolean type")
363366
} else if (!valueTypesEqual) {
364-
Some("THEN and ELSE expressions should all be same type")
367+
TypeCheckResult.fail("THEN and ELSE expressions should all be same type")
365368
} else {
366-
None
369+
TypeCheckResult.success
367370
}
368371
}
369372

@@ -408,11 +411,11 @@ case class CaseKeyWhen(key: Expression, branches: Seq[Expression]) extends CaseW
408411

409412
override def children: Seq[Expression] = key +: branches
410413

411-
override def checkInputDataTypes: Option[String] = {
414+
override def checkInputDataTypes: TypeCheckResult = {
412415
if (!valueTypesEqual) {
413-
Some("THEN and ELSE expressions should all be same type")
416+
TypeCheckResult.fail("THEN and ELSE expressions should all be same type")
414417
} else {
415-
None
418+
TypeCheckResult.success
416419
}
417420
}
418421

0 commit comments

Comments
 (0)