Skip to content

Commit 8883025

Browse files
committed
apply type check interface to CaseWhen
1 parent cffb67c commit 8883025

File tree

3 files changed

+48
-38
lines changed

3 files changed

+48
-38
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala

Lines changed: 21 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -599,9 +599,8 @@ trait HiveTypeCoercion {
599599
// from the list. So we need to make sure the return type is deterministic and
600600
// compatible with every child column.
601601
case Coalesce(es) if es.map(_.dataType).distinct.size > 1 =>
602-
val dt: Option[DataType] = Some(NullType)
603602
val types = es.map(_.dataType)
604-
val rt = types.foldLeft(dt)((r, c) => r match {
603+
val rt = types.foldLeft[Option[DataType]](Some(NullType))((r, c) => r match {
605604
case None => None
606605
case Some(d) => findTightestCommonType(d, c)
607606
})
@@ -635,28 +634,30 @@ trait HiveTypeCoercion {
635634
* Coerces the type of different branches of a CASE WHEN statement to a common type.
636635
*/
637636
object CaseWhenCoercion extends Rule[LogicalPlan] {
637+
638638
import HiveTypeCoercion._
639639

640640
def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
641-
case cw: CaseWhenLike if cw.childrenResolved && !cw.valueTypesEqual =>
641+
case cw: CaseWhenLike if cw.childrenResolved && cw.checkInputDataTypes().hasError =>
642642
logDebug(s"Input values for null casting ${cw.valueTypes.mkString(",")}")
643-
val commonType = cw.valueTypes.reduce { (v1, v2) =>
644-
findTightestCommonType(v1, v2).getOrElse(sys.error(
645-
s"Types in CASE WHEN must be the same or coercible to a common type: $v1 != $v2"))
646-
}
647-
val transformedBranches = cw.branches.sliding(2, 2).map {
648-
case Seq(when, value) if value.dataType != commonType =>
649-
Seq(when, Cast(value, commonType))
650-
case Seq(elseVal) if elseVal.dataType != commonType =>
651-
Seq(Cast(elseVal, commonType))
652-
case s => s
653-
}.reduce(_ ++ _)
654-
cw match {
655-
case _: CaseWhen =>
656-
CaseWhen(transformedBranches)
657-
case CaseKeyWhen(key, _) =>
658-
CaseKeyWhen(key, transformedBranches)
659-
}
643+
cw.valueTypes.foldLeft[Option[DataType]](Some(NullType))((r, c) => r match {
644+
case None => None
645+
case Some(d) => findTightestCommonType(d, c)
646+
}).map { commonType =>
647+
val transformedBranches = cw.branches.sliding(2, 2).map {
648+
case Seq(when, value) if value.dataType != commonType =>
649+
Seq(when, Cast(value, commonType))
650+
case Seq(elseVal) if elseVal.dataType != commonType =>
651+
Seq(Cast(elseVal, commonType))
652+
case s => s
653+
}.reduce(_ ++ _)
654+
cw match {
655+
case _: CaseWhen =>
656+
CaseWhen(transformedBranches)
657+
case CaseKeyWhen(key, _) =>
658+
CaseKeyWhen(key, transformedBranches)
659+
}
660+
}.getOrElse(cw)
660661

661662
case ckw: CaseKeyWhen if ckw.childrenResolved && !ckw.resolved =>
662663
val commonType = (ckw.key +: ckw.whenList).map(_.dataType).reduce { (v1, v2) =>

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala

Lines changed: 17 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -325,7 +325,18 @@ trait CaseWhenLike extends Expression {
325325
def valueTypes: Seq[DataType] = (thenList ++ elseValue).map(_.dataType)
326326
def valueTypesEqual: Boolean = valueTypes.distinct.size == 1
327327

328-
override def dataType: DataType = valueTypes.head
328+
override def checkInputDataTypes(): TypeCheckResult = {
329+
if (valueTypes.distinct.size > 1) {
330+
TypeCheckResult.fail(
331+
"THEN and ELSE expressions should all be same type or coercible to a common type")
332+
} else {
333+
checkTypesInternal()
334+
}
335+
}
336+
337+
protected def checkTypesInternal(): TypeCheckResult
338+
339+
override def dataType: DataType = thenList.head.dataType
329340

330341
override def nullable: Boolean = {
331342
// If no value is nullable and no elseValue is provided, the whole statement defaults to null.
@@ -347,14 +358,11 @@ case class CaseWhen(branches: Seq[Expression]) extends CaseWhenLike {
347358

348359
override def children: Seq[Expression] = branches
349360

350-
override def checkInputDataTypes(): TypeCheckResult = {
351-
if (!whenList.forall(_.dataType == BooleanType)) {
352-
TypeCheckResult.fail(s"WHEN expressions should all be boolean type")
353-
} else if (!valueTypesEqual) {
354-
TypeCheckResult.fail(
355-
"THEN and ELSE expressions should all be same type or coercible to a common type")
356-
} else {
361+
override protected def checkTypesInternal(): TypeCheckResult = {
362+
if (whenList.forall(_.dataType == BooleanType)) {
357363
TypeCheckResult.success
364+
} else {
365+
TypeCheckResult.fail(s"WHEN expressions in CaseWhen should all be boolean type")
358366
}
359367
}
360368

@@ -399,14 +407,7 @@ case class CaseKeyWhen(key: Expression, branches: Seq[Expression]) extends CaseW
399407

400408
override def children: Seq[Expression] = key +: branches
401409

402-
override def checkInputDataTypes(): TypeCheckResult = {
403-
if (!valueTypesEqual) {
404-
TypeCheckResult.fail(
405-
"THEN and ELSE expressions should all be same type or coercible to a common type")
406-
} else {
407-
TypeCheckResult.success
408-
}
409-
}
410+
override protected def checkTypesInternal(): TypeCheckResult = TypeCheckResult.success
410411

411412
/** Written in imperative fashion for performance considerations. */
412413
override def eval(input: Row): Any = {

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionTypeCheckingSuite.scala

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,15 @@ class ExpressionTypeCheckingSuite extends FunSuite {
126126
"type of predicate expression in If should be boolean")
127127
assertErrorForDifferingTypes(If('booleanField, 'intField, 'booleanField))
128128

129-
// Will write tests for CaseWhen later,
130-
// as the error reporting of it is not handled by the new interface for now
129+
assertError(
130+
CaseWhen(Seq('booleanField, 'intField, 'booleanField, 'complexField)),
131+
"THEN and ELSE expressions should all be same type or coercible to a common type")
132+
assertError(
133+
CaseKeyWhen('intField, Seq('intField, 'stringField, 'intField, 'complexField)),
134+
"THEN and ELSE expressions should all be same type or coercible to a common type")
135+
assertError(
136+
CaseWhen(Seq('booleanField, 'intField, 'intField, 'intField)),
137+
"WHEN expressions in CaseWhen should all be boolean type")
138+
131139
}
132140
}

0 commit comments

Comments
 (0)