From b5a50f49027df71ae407634d6b0c609b23ed8f05 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Tue, 22 Nov 2016 20:19:58 +0800 Subject: [PATCH 1/2] map type can not be used in EqualTo --- .../sql/catalyst/analysis/CheckAnalysis.scala | 15 ------ .../sql/catalyst/expressions/predicates.scala | 46 +++++++++++++++++++ .../analysis/AnalysisErrorSuite.scala | 44 +++++++----------- .../ExpressionTypeCheckingSuite.scala | 2 + 4 files changed, 64 insertions(+), 43 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index 98e50d0d3c67..80e577e5c4c7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -183,21 +183,6 @@ trait CheckAnalysis extends PredicateHelper { s"join condition '${condition.sql}' " + s"of type ${condition.dataType.simpleString} is not a boolean.") - case j @ Join(_, _, _, Some(condition)) => - def checkValidJoinConditionExprs(expr: Expression): Unit = expr match { - case p: Predicate => - p.asInstanceOf[Expression].children.foreach(checkValidJoinConditionExprs) - case e if e.dataType.isInstanceOf[BinaryType] => - failAnalysis(s"binary type expression ${e.sql} cannot be used " + - "in join conditions") - case e if e.dataType.isInstanceOf[MapType] => - failAnalysis(s"map type expression ${e.sql} cannot be used " + - "in join conditions") - case _ => // OK - } - - checkValidJoinConditionExprs(condition) - case Aggregate(groupingExprs, aggregateExprs, child) => def checkValidAggregateExpression(expr: Expression): Unit = expr match { case aggExpr: AggregateExpression => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index 7946c201f4ff..dbff57b60314 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -412,6 +412,29 @@ case class EqualTo(left: Expression, right: Expression) override def inputType: AbstractDataType = AnyDataType + override def checkInputDataTypes(): TypeCheckResult = { + super.checkInputDataTypes() match { + case TypeCheckResult.TypeCheckSuccess => + // TODO: although map type is not orderable, technically map type should be able to be used + // in equality comparison, remove this type check once we support it. + if (hasMapType(left.dataType)) { + TypeCheckResult.TypeCheckFailure("Cannot use map type in EqualTo, but the actual " + + s"input type is ${left.dataType.catalogString}.") + } else { + TypeCheckResult.TypeCheckSuccess + } + case failure => failure + } + } + + private def hasMapType(dt: DataType): Boolean = dt match { + case _: MapType => true + case st: StructType => st.map(_.dataType).exists(hasMapType) + case a: ArrayType => hasMapType(a.elementType) + case udt: UserDefinedType[_] => hasMapType(udt.sqlType) + case _ => false + } + override def symbol: String = "=" protected override def nullSafeEval(input1: Any, input2: Any): Any = { @@ -440,6 +463,29 @@ case class EqualNullSafe(left: Expression, right: Expression) extends BinaryComp override def inputType: AbstractDataType = AnyDataType + override def checkInputDataTypes(): TypeCheckResult = { + super.checkInputDataTypes() match { + case TypeCheckResult.TypeCheckSuccess => + // TODO: although map type is not orderable, technically map type should be able to be used + // in equality comparison, remove this type check once we support it. + if (hasMapType(left.dataType)) { + TypeCheckResult.TypeCheckFailure("Cannot use map type in EqualNullSafe, but the actual " + + s"input type is ${left.dataType.catalogString}.") + } else { + TypeCheckResult.TypeCheckSuccess + } + case failure => failure + } + } + + private def hasMapType(dt: DataType): Boolean = dt match { + case _: MapType => true + case st: StructType => st.map(_.dataType).exists(hasMapType) + case a: ArrayType => hasMapType(a.elementType) + case udt: UserDefinedType[_] => hasMapType(udt.sqlType) + case _ => false + } + override def symbol: String = "<=>" override def nullable: Boolean = false diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala index 21afe9fec594..8c1faea2394c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala @@ -465,34 +465,22 @@ class AnalysisErrorSuite extends AnalysisTest { "another aggregate function." :: Nil) } - test("Join can't work on binary and map types") { - val plan = - Join( - LocalRelation( - AttributeReference("a", BinaryType)(exprId = ExprId(2)), - AttributeReference("b", IntegerType)(exprId = ExprId(1))), - LocalRelation( - AttributeReference("c", BinaryType)(exprId = ExprId(4)), - AttributeReference("d", IntegerType)(exprId = ExprId(3))), - Cross, - Some(EqualTo(AttributeReference("a", BinaryType)(exprId = ExprId(2)), - AttributeReference("c", BinaryType)(exprId = ExprId(4))))) - - assertAnalysisError(plan, "binary type expression `a` cannot be used in join conditions" :: Nil) - - val plan2 = - Join( - LocalRelation( - AttributeReference("a", MapType(IntegerType, StringType))(exprId = ExprId(2)), - AttributeReference("b", IntegerType)(exprId = ExprId(1))), - LocalRelation( - AttributeReference("c", MapType(IntegerType, StringType))(exprId = ExprId(4)), - AttributeReference("d", IntegerType)(exprId = ExprId(3))), - Cross, - Some(EqualTo(AttributeReference("a", MapType(IntegerType, StringType))(exprId = ExprId(2)), - AttributeReference("c", MapType(IntegerType, StringType))(exprId = ExprId(4))))) - - assertAnalysisError(plan2, "map type expression `a` cannot be used in join conditions" :: Nil) + test("Join can work on binary types but can't work on map types") { + val left = LocalRelation('a.binary, 'b.map(StringType, StringType)) + val right = LocalRelation('c.binary, 'd.map(StringType, StringType)) + + val plan1 = left.join( + right, + joinType = Cross, + condition = Some('a === 'c)) + + assertAnalysisSuccess(plan1) + + val plan2 = left.join( + right, + joinType = Cross, + condition = Some('b === 'd)) + assertAnalysisError(plan2, "Cannot use map type in EqualTo" :: Nil) } test("PredicateSubQuery is used outside of a filter") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala index 542e654bbce1..744057b7c5f4 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala @@ -111,6 +111,8 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite { assertErrorForDifferingTypes(GreaterThan('intField, 'booleanField)) assertErrorForDifferingTypes(GreaterThanOrEqual('intField, 'booleanField)) + assertError(EqualTo('mapField, 'mapField), "Cannot use map type in EqualTo") + assertError(EqualNullSafe('mapField, 'mapField), "Cannot use map type in EqualNullSafe") assertError(LessThan('mapField, 'mapField), s"requires ${TypeCollection.Ordered.simpleString} type") assertError(LessThanOrEqual('mapField, 'mapField), From dff0b0801026b2c047066c0ca7b4545ea676eef7 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Tue, 22 Nov 2016 22:41:00 +0800 Subject: [PATCH 2/2] address comments --- .../sql/catalyst/expressions/predicates.scala | 20 ++----------------- 1 file changed, 2 insertions(+), 18 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index dbff57b60314..2ad452b6a90c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -417,7 +417,7 @@ case class EqualTo(left: Expression, right: Expression) case TypeCheckResult.TypeCheckSuccess => // TODO: although map type is not orderable, technically map type should be able to be used // in equality comparison, remove this type check once we support it. - if (hasMapType(left.dataType)) { + if (left.dataType.existsRecursively(_.isInstanceOf[MapType])) { TypeCheckResult.TypeCheckFailure("Cannot use map type in EqualTo, but the actual " + s"input type is ${left.dataType.catalogString}.") } else { @@ -427,14 +427,6 @@ case class EqualTo(left: Expression, right: Expression) } } - private def hasMapType(dt: DataType): Boolean = dt match { - case _: MapType => true - case st: StructType => st.map(_.dataType).exists(hasMapType) - case a: ArrayType => hasMapType(a.elementType) - case udt: UserDefinedType[_] => hasMapType(udt.sqlType) - case _ => false - } - override def symbol: String = "=" protected override def nullSafeEval(input1: Any, input2: Any): Any = { @@ -468,7 +460,7 @@ case class EqualNullSafe(left: Expression, right: Expression) extends BinaryComp case TypeCheckResult.TypeCheckSuccess => // TODO: although map type is not orderable, technically map type should be able to be used // in equality comparison, remove this type check once we support it. - if (hasMapType(left.dataType)) { + if (left.dataType.existsRecursively(_.isInstanceOf[MapType])) { TypeCheckResult.TypeCheckFailure("Cannot use map type in EqualNullSafe, but the actual " + s"input type is ${left.dataType.catalogString}.") } else { @@ -478,14 +470,6 @@ case class EqualNullSafe(left: Expression, right: Expression) extends BinaryComp } } - private def hasMapType(dt: DataType): Boolean = dt match { - case _: MapType => true - case st: StructType => st.map(_.dataType).exists(hasMapType) - case a: ArrayType => hasMapType(a.elementType) - case udt: UserDefinedType[_] => hasMapType(udt.sqlType) - case _ => false - } - override def symbol: String = "<=>" override def nullable: Boolean = false