Skip to content

Commit acfa6ba

Browse files
committed
map type can not be used in EqualTo
1 parent 072f4c5 commit acfa6ba

File tree

4 files changed

+48
-43
lines changed

4 files changed

+48
-43
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala

Lines changed: 0 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -191,21 +191,6 @@ trait CheckAnalysis extends PredicateHelper {
191191
s"join condition '${condition.sql}' " +
192192
s"of type ${condition.dataType.simpleString} is not a boolean.")
193193

194-
case j @ Join(_, _, _, Some(condition)) =>
195-
def checkValidJoinConditionExprs(expr: Expression): Unit = expr match {
196-
case p: Predicate =>
197-
p.asInstanceOf[Expression].children.foreach(checkValidJoinConditionExprs)
198-
case e if e.dataType.isInstanceOf[BinaryType] =>
199-
failAnalysis(s"binary type expression ${e.sql} cannot be used " +
200-
"in join conditions")
201-
case e if e.dataType.isInstanceOf[MapType] =>
202-
failAnalysis(s"map type expression ${e.sql} cannot be used " +
203-
"in join conditions")
204-
case _ => // OK
205-
}
206-
207-
checkValidJoinConditionExprs(condition)
208-
209194
case Aggregate(groupingExprs, aggregateExprs, child) =>
210195
def checkValidAggregateExpression(expr: Expression): Unit = expr match {
211196
case aggExpr: AggregateExpression =>

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -415,6 +415,21 @@ case class EqualTo(left: Expression, right: Expression)
415415

416416
override def inputType: AbstractDataType = AnyDataType
417417

418+
override def checkInputDataTypes(): TypeCheckResult = {
419+
super.checkInputDataTypes() match {
420+
case TypeCheckResult.TypeCheckSuccess =>
421+
// TODO: although map type is not orderable, technically map type should be able to be used
422+
// in equality comparison, remove this type check once we support it.
423+
if (left.dataType.existsRecursively(_.isInstanceOf[MapType])) {
424+
TypeCheckResult.TypeCheckFailure("Cannot use map type in EqualTo, but the actual " +
425+
s"input type is ${left.dataType.catalogString}.")
426+
} else {
427+
TypeCheckResult.TypeCheckSuccess
428+
}
429+
case failure => failure
430+
}
431+
}
432+
418433
override def symbol: String = "="
419434

420435
protected override def nullSafeEval(input1: Any, input2: Any): Any = {
@@ -441,6 +456,21 @@ case class EqualNullSafe(left: Expression, right: Expression) extends BinaryComp
441456

442457
override def inputType: AbstractDataType = AnyDataType
443458

459+
override def checkInputDataTypes(): TypeCheckResult = {
460+
super.checkInputDataTypes() match {
461+
case TypeCheckResult.TypeCheckSuccess =>
462+
// TODO: although map type is not orderable, technically map type should be able to be used
463+
// in equality comparison, remove this type check once we support it.
464+
if (left.dataType.existsRecursively(_.isInstanceOf[MapType])) {
465+
TypeCheckResult.TypeCheckFailure("Cannot use map type in EqualNullSafe, but the actual " +
466+
s"input type is ${left.dataType.catalogString}.")
467+
} else {
468+
TypeCheckResult.TypeCheckSuccess
469+
}
470+
case failure => failure
471+
}
472+
}
473+
444474
override def symbol: String = "<=>"
445475

446476
override def nullable: Boolean = false

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala

Lines changed: 16 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -451,34 +451,22 @@ class AnalysisErrorSuite extends AnalysisTest {
451451
"another aggregate function." :: Nil)
452452
}
453453

454-
test("Join can't work on binary and map types") {
455-
val plan =
456-
Join(
457-
LocalRelation(
458-
AttributeReference("a", BinaryType)(exprId = ExprId(2)),
459-
AttributeReference("b", IntegerType)(exprId = ExprId(1))),
460-
LocalRelation(
461-
AttributeReference("c", BinaryType)(exprId = ExprId(4)),
462-
AttributeReference("d", IntegerType)(exprId = ExprId(3))),
463-
Inner,
464-
Some(EqualTo(AttributeReference("a", BinaryType)(exprId = ExprId(2)),
465-
AttributeReference("c", BinaryType)(exprId = ExprId(4)))))
466-
467-
assertAnalysisError(plan, "binary type expression `a` cannot be used in join conditions" :: Nil)
468-
469-
val plan2 =
470-
Join(
471-
LocalRelation(
472-
AttributeReference("a", MapType(IntegerType, StringType))(exprId = ExprId(2)),
473-
AttributeReference("b", IntegerType)(exprId = ExprId(1))),
474-
LocalRelation(
475-
AttributeReference("c", MapType(IntegerType, StringType))(exprId = ExprId(4)),
476-
AttributeReference("d", IntegerType)(exprId = ExprId(3))),
477-
Inner,
478-
Some(EqualTo(AttributeReference("a", MapType(IntegerType, StringType))(exprId = ExprId(2)),
479-
AttributeReference("c", MapType(IntegerType, StringType))(exprId = ExprId(4)))))
480-
481-
assertAnalysisError(plan2, "map type expression `a` cannot be used in join conditions" :: Nil)
454+
test("Join can work on binary types but can't work on map types") {
455+
val left = LocalRelation('a.binary, 'b.map(StringType, StringType))
456+
val right = LocalRelation('c.binary, 'd.map(StringType, StringType))
457+
458+
val plan1 = left.join(
459+
right,
460+
joinType = Inner,
461+
condition = Some('a === 'c))
462+
463+
assertAnalysisSuccess(plan1)
464+
465+
val plan2 = left.join(
466+
right,
467+
joinType = Inner,
468+
condition = Some('b === 'd))
469+
assertAnalysisError(plan2, "Cannot use map type in EqualTo" :: Nil)
482470
}
483471

484472
test("PredicateSubQuery is used outside of a filter") {

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,8 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite {
118118
assertErrorForDifferingTypes(GreaterThan('intField, 'booleanField))
119119
assertErrorForDifferingTypes(GreaterThanOrEqual('intField, 'booleanField))
120120

121+
assertError(EqualTo('mapField, 'mapField), "Cannot use map type in EqualTo")
122+
assertError(EqualNullSafe('mapField, 'mapField), "Cannot use map type in EqualNullSafe")
121123
assertError(LessThan('mapField, 'mapField),
122124
s"requires ${TypeCollection.Ordered.simpleString} type")
123125
assertError(LessThanOrEqual('mapField, 'mapField),

0 commit comments

Comments
 (0)