From a98908032316007aa26bfad3d6ef0f42ef523666 Mon Sep 17 00:00:00 2001
From: SaurabhChawla
Date: Fri, 6 Aug 2021 11:48:51 +0530
Subject: [PATCH 1/3] Add support for MapType in GROUP BY in Spark SQL

---
 .../sql/catalyst/analysis/CheckAnalysis.scala |  2 +-
 .../sql/catalyst/expressions/ordering.scala   |  9 +++++--
 .../optimizer/NormalizeFloatingNumbers.scala  |  2 +-
 .../spark/sql/execution/SparkStrategies.scala | 20 +++++++++++----
 .../sql/execution/aggregate/AggUtils.scala    | 25 +++++++++++++------
 5 files changed, 41 insertions(+), 17 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
index 043bf9594327b..fcb79a9d50e45 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
@@ -318,7 +318,7 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog {
           }
 
           // Check if the data type of expr is orderable.
-          if (!RowOrdering.isOrderable(expr.dataType)) {
+          if (!RowOrdering.isOrderable(expr.dataType, isGroupingExpr = true)) {
             failAnalysis(
               s"expression ${expr.sql} cannot be used as a grouping expression " +
                 s"because its data type ${expr.dataType.catalogString} is not an orderable " +
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ordering.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ordering.scala
index ba3ed02e06ef1..ef32bb10c1c33 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ordering.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ordering.scala
@@ -97,13 +97,18 @@ object InterpretedOrdering {
 object RowOrdering extends CodeGeneratorWithInterpretedFallback[Seq[SortOrder], BaseOrdering] {
 
   /**
-   * Returns true iff the data type can be ordered (i.e. can be sorted).
+   * Returns true if the data type can be ordered (i.e. can be sorted).
    */
-  def isOrderable(dataType: DataType): Boolean = dataType match {
+  def isOrderable(dataType: DataType,
+      isGroupingExpr: Boolean = false): Boolean = dataType match {
     case NullType => true
     case dt: AtomicType => true
     case struct: StructType => struct.fields.forall(f => isOrderable(f.dataType))
     case array: ArrayType => isOrderable(array.elementType)
+    // Support MapType only when the orderability check comes from
+    // analysis of a grouping expression.
+    case map: MapType if isGroupingExpr =>
+      isOrderable(map.keyType) && isOrderable(map.valueType)
     case udt: UserDefinedType[_] => isOrderable(udt.sqlType)
     case _ => false
   }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingNumbers.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingNumbers.scala
index 6d0f46baa0984..46a7f286bee62 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingNumbers.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingNumbers.scala
@@ -92,7 +92,7 @@ object NormalizeFloatingNumbers extends Rule[LogicalPlan] {
     case _ => needNormalize(expr.dataType)
   }
 
-  private def needNormalize(dt: DataType): Boolean = dt match {
+  private[sql] def needNormalize(dt: DataType): Boolean = dt match {
     case FloatType | DoubleType => true
     case StructType(fields) => fields.exists(f => needNormalize(f.dataType))
     case ArrayType(et, _) => needNormalize(et)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 6ebfba2c02957..7f92f1a739e64 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -40,7 +40,7 @@ import org.apache.spark.sql.execution.streaming._
 import org.apache.spark.sql.execution.streaming.sources.MemoryPlan
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.streaming.OutputMode
-import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.types.{MapType, StructType}
 
 /**
  * Converts a logical plan into zero or more SparkPlans. This API is exposed for experimenting
@@ -498,10 +498,20 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
         // Ideally this should be done in `NormalizeFloatingNumbers`, but we do it here because
         // `groupingExpressions` is not extracted during logical phase.
         val normalizedGroupingExpressions = groupingExpressions.map { e =>
-          NormalizeFloatingNumbers.normalize(e) match {
-            case n: NamedExpression => n
-            // Keep the name of the original expression.
-            case other => Alias(other, e.name)(exprId = e.exprId)
+          e.dataType match {
+            // Support MapType in GROUP BY when no aggregate expression
+            // references the map attribute and neither the key type nor
+            // the value type needs floating-point normalization.
+            case MapType(kt, vt, _)
+              if !aggregateExpressions.exists(_.references == e.references) &&
+                !NormalizeFloatingNumbers.needNormalize(kt) &&
+                !NormalizeFloatingNumbers.needNormalize(vt) => e
+            case _ =>
+              NormalizeFloatingNumbers.normalize(e) match {
+                case n: NamedExpression => n
+                // Keep the name of the original expression.
+                case other => Alias(other, e.name)(exprId = e.exprId)
+              }
           }
         }
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala
index 0f239b457fd14..1879959b161b4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala
@@ -22,6 +22,7 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate._
 import org.apache.spark.sql.execution.SparkPlan
 import org.apache.spark.sql.execution.streaming._
+import org.apache.spark.sql.types.MapType
 
 /**
  * Utility functions used by the query planner to convert our plan to new aggregation code path.
@@ -76,14 +77,22 @@ object AggUtils {
         resultExpressions = resultExpressions,
         child = child)
     } else {
-      SortAggregateExec(
-        requiredChildDistributionExpressions = requiredChildDistributionExpressions,
-        groupingExpressions = groupingExpressions,
-        aggregateExpressions = mayRemoveAggFilters(aggregateExpressions),
-        aggregateAttributes = aggregateAttributes,
-        initialInputBufferOffset = initialInputBufferOffset,
-        resultExpressions = resultExpressions,
-        child = child)
+      // SortAggregateExec checks that the data type of each grouping
+      // expression is orderable, and MapType is not orderable there.
+      // Guard against planning a sort aggregate over a map-type
+      // grouping expression.
+      if (!groupingExpressions.exists(_.dataType.isInstanceOf[MapType])) {
+        SortAggregateExec(
+          requiredChildDistributionExpressions = requiredChildDistributionExpressions,
+          groupingExpressions = groupingExpressions,
+          aggregateExpressions = mayRemoveAggFilters(aggregateExpressions),
+          aggregateAttributes = aggregateAttributes,
+          initialInputBufferOffset = initialInputBufferOffset,
+          resultExpressions = resultExpressions,
+          child = child)
+      } else {
+        throw new IllegalStateException("grouping keys cannot be map type for SortAggregateExec")
+      }
     }
   }
 }

From e6505d154b7187fd364224cfda4cbe2bd5430d2a Mon Sep 17 00:00:00 2001
From: SaurabhChawla
Date: Sun, 8 Aug 2021 18:19:55 +0530
Subject: [PATCH 2/3] Add unit tests for map columns in GROUP BY

---
 .../analysis/AnalysisErrorSuite.scala         |  2 +-
 .../spark/sql/DataFrameAggregateSuite.scala   | 27 +++++++++++++++++++
 2 files changed, 28 insertions(+), 1 deletion(-)

diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
index 6cda05360aea3..cd2a886e3a341 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
@@ -588,6 +588,7 @@ class AnalysisErrorSuite extends AnalysisTest {
       FloatType, DoubleType, DecimalType(25, 5), DecimalType(6, 5),
       DateType, TimestampType,
       ArrayType(IntegerType),
+      MapType(StringType, LongType),
       new StructType()
         .add("f1", FloatType, nullable = true)
         .add("f2", StringType, nullable = true),
@@ -600,7 +601,6 @@ class AnalysisErrorSuite extends AnalysisTest {
     }
 
     val unsupportedDataTypes = Seq(
-      MapType(StringType, LongType),
       new StructType()
         .add("f1", FloatType, nullable = true)
         .add("f2", MapType(StringType, LongType), nullable = true),
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
index d0a122e0fe094..d7c8bad423abe 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
@@ -1427,6 +1427,33 @@ class DataFrameAggregateSuite extends QueryTest
     assert (df.schema == expectedSchema)
     checkAnswer(df, Seq(Row(LocalDateTime.parse(ts1), 2), Row(LocalDateTime.parse(ts2), 1)))
   }
+
+  test("SPARK-36452: Support MapType column in group by") {
+    var df = Seq((1, Map(1 -> 2)), (2, Map(1 -> 2))).toDF("id", "mapInfo")
+    // Group by the map column alone.
+    checkAnswer(df.groupBy("mapInfo").count(), Seq(Row(Map[Any, Any](1 -> 2), 2)))
+    // Group by the map column together with another column.
+    checkAnswer(df.groupBy("id", "mapInfo").count(),
+      Seq(Row(1, Map[Any, Any](1 -> 2), 1), Row(2, Map[Any, Any](1 -> 2), 1)))
+    checkAnswer(df.groupBy("mapInfo").agg(avg("id")),
+      Seq(Row(Map[Any, Any](1 -> 2), 1.5)))
+    // Not supported when the map column is also referenced in an aggregate expression.
+    var error = intercept[IllegalStateException] {
+      df.groupBy("mapInfo").agg(max(map_keys(col("mapinfo")))).collect
+    }
+    assert(error.getMessage.contains("grouping/join/window partition keys cannot be map type."))
+    // Not supported when the map has float/double keys or values.
+    df = Seq((1, Map(1 -> 2.0)), (2, Map(1 -> 2.0))).toDF("id", "mapInfo")
+    error = intercept[IllegalStateException] {
+      df.groupBy("mapInfo").agg(max(map_keys(col("mapinfo")))).collect
+    }
+    assert(error.getMessage.contains("grouping/join/window partition keys cannot be map type."))
+    df = Seq((1, Map(1.1 -> 2.0)), (2, Map(1.1 -> 2.0))).toDF("id", "mapInfo")
+    error = intercept[IllegalStateException] {
+      df.groupBy("mapInfo").agg(max(map_keys(col("mapinfo")))).collect
+    }
+    assert(error.getMessage.contains("grouping/join/window partition keys cannot be map type."))
+  }
 }
 
 case class B(c: Option[Double])

From a09e37faec1be04a20785846f9a2b53c8fdfd663 Mon Sep 17 00:00:00 2001
From: SaurabhChawla
Date: Mon, 9 Aug 2021 11:19:33 +0530
Subject: [PATCH 3/3] Update the comment on RowOrdering.isOrderable

---
 .../org/apache/spark/sql/catalyst/expressions/ordering.scala | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ordering.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ordering.scala
index ef32bb10c1c33..7a6bd588ad699 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ordering.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ordering.scala
@@ -97,7 +97,7 @@ object InterpretedOrdering {
 object RowOrdering extends CodeGeneratorWithInterpretedFallback[Seq[SortOrder], BaseOrdering] {
 
   /**
-   * Returns true if the data type can be ordered (i.e. can be sorted).
+   * Returns true iff the data type can be ordered (i.e. can be sorted).
    */
   def isOrderable(dataType: DataType,
       isGroupingExpr: Boolean = false): Boolean = dataType match {
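Reviewer note (not part of the patch series): below is a minimal, self-contained sketch of the behavior these patches enable, written against the public DataFrame/SQL API. It assumes a local build that includes this series; the session setup, view name, and column names ("events", "id", "tags") are illustrative only and do not appear in the patches.

import org.apache.spark.sql.SparkSession

object GroupByMapTypeSketch {
  def main(args: Array[String]): Unit = {
    // Local session for illustration only.
    val spark = SparkSession.builder()
      .appName("GroupByMapTypeSketch")
      .master("local[*]")
      .getOrCreate()
    import spark.implicits._

    // Two rows share the same map value, one differs.
    val df = Seq(
      (1, Map("a" -> 1)),
      (2, Map("a" -> 1)),
      (3, Map("b" -> 2))).toDF("id", "tags")
    df.createOrReplaceTempView("events")

    // Before this series, analysis failed here because the grouping
    // expression's data type "is not an orderable data type" (see the
    // CheckAnalysis hunk in PATCH 1). With it, the query plans a hash
    // aggregate on the map-typed key.
    spark.sql("SELECT tags, COUNT(*) AS cnt FROM events GROUP BY tags").show(false)
    // Expected: Map(a -> 1) with cnt 2, Map(b -> 2) with cnt 1.

    spark.stop()
  }
}

Per the SparkStrategies and AggUtils hunks above, only the hash-aggregate path supports map-typed grouping keys; maps with float/double keys or values, and map columns that are also referenced inside an aggregate expression, still fail, as exercised by the DataFrameAggregateSuite test in PATCH 2.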