From c91f4417a17e51fe4e5d47fa196695fe4179d518 Mon Sep 17 00:00:00 2001 From: Uros Bojanic <157381213+uros-db@users.noreply.github.com> Date: Mon, 8 Apr 2024 13:49:01 +0200 Subject: [PATCH 1/9] Draft commit --- .../analysis/RewriteGroupByCollation.scala | 101 ++++++++++++++ .../sql/catalyst/expressions/ExprUtils.scala | 8 +- .../optimizer/MergeScalarSubqueries.scala | 9 +- .../sql/catalyst/optimizer/Optimizer.scala | 3 +- .../plans/logical/basicLogicalOperators.scala | 6 +- .../sql/execution/aggregate/AggUtils.scala | 2 +- .../aggregate/HashAggregateExec.scala | 40 +++++- .../org/apache/spark/sql/CollationSuite.scala | 124 +++++++++++++++--- 8 files changed, 258 insertions(+), 35 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteGroupByCollation.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteGroupByCollation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteGroupByCollation.scala new file mode 100644 index 0000000000000..ce79ee9ae5753 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteGroupByCollation.scala @@ -0,0 +1,101 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.analysis + +import java.util.Locale + +// import scala.collection.mutable + +import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference, ExpectsInputTypes, Expression, StringTypeAnyCollation, UnaryExpression} +import org.apache.spark.sql.catalyst.expressions.aggregate.AnyValue +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} +// import org.apache.spark.sql.catalyst.expressions.codegen.Block.BlockHelper +import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan} +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.catalyst.util.CollationFactory +import org.apache.spark.sql.types.{AbstractDataType, DataType, StringType} +import org.apache.spark.unsafe.types.UTF8String + +/** + * A rule that rewrites aggregation operations using plans that operate on collated strings. 
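+ *
+ * For illustration (conceptual plans, not exact optimizer output): when grouping on a
+ * column `name STRING COLLATE UTF8_BINARY_LCASE`, a plan such as
+ *   Aggregate [name], [count(1), name]
+ * is rewritten to, roughly,
+ *   Aggregate [collationkey(name)], [count(1), anyvalue(name) AS name]
+ * so that strings which compare equal under the collation share one grouping key, while
+ * an order-agnostic aggregate recovers a representative original value for the output.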
+ */ +object RewriteGroupByCollation extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = plan transformUpWithNewOutput { + case a: Aggregate => + val aliasMap = a.groupingExpressions.collect { + case attr: AttributeReference if attr.dataType.isInstanceOf[StringType] => + attr -> CollationKey(attr) // Alias(CollationKey(attr), attr.name)() + }.toMap + + val newGroupingExpressions = a.groupingExpressions.map { + case attr: AttributeReference if aliasMap.contains(attr) => + aliasMap(attr) + case other => other + } + + val newAggregateExpressions = a.aggregateExpressions.map { + case attr: AttributeReference if aliasMap.contains(attr) => + Alias(AnyValue(attr, ignoreNulls = false), attr.name)() + case other => other + } + + val newAggregate = a.copy( + groupingExpressions = newGroupingExpressions, + aggregateExpressions = newAggregateExpressions + ) + + (newAggregate, a.output.zip(newAggregate.output)) + } +} + +case class CollationKey(expr: Expression) extends UnaryExpression with ExpectsInputTypes { + override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeAnyCollation) + override def dataType: DataType = expr.dataType + + final lazy val collationId: Int = dataType.asInstanceOf[StringType].collationId + + override def nullSafeEval(input: Any): Any = { + val str: UTF8String = input.asInstanceOf[UTF8String] + if (collationId == CollationFactory.UTF8_BINARY_COLLATION_ID) { + str + } else if (collationId == CollationFactory.UTF8_BINARY_LCASE_COLLATION_ID) { + UTF8String.fromString(str.toString.toLowerCase(Locale.ROOT)) + } else { + val collator = CollationFactory.fetchCollation(collationId).collator + val collationKey = collator.getCollationKey(str.toString) + UTF8String.fromBytes(collationKey.toByteArray) + } + } + + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + if (collationId == CollationFactory.UTF8_BINARY_COLLATION_ID) { + defineCodeGen(ctx, ev, c => s"$c") + } else if (collationId == CollationFactory.UTF8_BINARY_LCASE_COLLATION_ID) { + defineCodeGen(ctx, ev, c => s"$c.toLowerCase()") + } else { + defineCodeGen(ctx, ev, c => s"UTF8String.fromBytes(CollationFactory.fetchCollation" + + s"($collationId).collator.getCollationKey($c.toString()).toByteArray())") + } + } + + override protected def withNewChildInternal(newChild: Expression): Expression = { + copy(expr = newChild) + } + + override def child: Expression = expr +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExprUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExprUtils.scala index 258bc0ed8fe73..9642e8440bc48 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExprUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExprUtils.scala @@ -168,7 +168,13 @@ object ExprUtils extends QueryErrorsBase { a.failAnalysis( errorClass = "MISSING_GROUP_BY", messageParameters = Map.empty) - case e: Attribute if !a.groupingExpressions.exists(_.semanticEquals(e)) => + case e: Attribute if !a.groupingExpressions.exists(_.semanticEquals(e)) && + !a.groupingExpressions.exists { + case al: Alias => al.child.semanticEquals(CollationKey(e)) + case ck: CollationKey => ck.semanticEquals(CollationKey(e)) + case other => other.semanticEquals(e) + } + => throw QueryCompilationErrors.columnNotInGroupByClauseError(e) case s: ScalarSubquery if s.children.nonEmpty && !a.groupingExpressions.exists(_.semanticEquals(s)) => diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/MergeScalarSubqueries.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/MergeScalarSubqueries.scala index 2d1e71a63a8ce..24f10854f7dfe 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/MergeScalarSubqueries.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/MergeScalarSubqueries.scala @@ -353,15 +353,10 @@ object MergeScalarSubqueries extends Rule[LogicalPlan] { case a: AggregateExpression => a }) } - val groupByExpressionSeq = Seq(newPlan, cachedPlan).map(_.groupingExpressions) val Seq(newPlanSupportsHashAggregate, cachedPlanSupportsHashAggregate) = - aggregateExpressionsSeq.zip(groupByExpressionSeq).map { - case (aggregateExpressions, groupByExpressions) => - Aggregate.supportsHashAggregate( - aggregateExpressions.flatMap( - _.aggregateFunction.aggBufferAttributes), groupByExpressions) - } + aggregateExpressionsSeq.map(aggregateExpressions => Aggregate.supportsHashAggregate( + aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes))) newPlanSupportsHashAggregate && cachedPlanSupportsHashAggregate || newPlanSupportsHashAggregate == cachedPlanSupportsHashAggregate && { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 122cba5b74f8d..abff866fb3ce1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -131,7 +131,8 @@ abstract class Optimizer(catalogManager: CatalogManager) SimplifyExtractValueOps, OptimizeCsvJsonExprs, CombineConcats, - PushdownPredicatesAndPruneColumnsForCTEDef) ++ + PushdownPredicatesAndPruneColumnsForCTEDef, + RewriteGroupByCollation) ++ extendedOperatorOptimizationRules val operatorOptimizationBatch: Seq[Batch] = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index 1c8f7a97dd7fe..7c2dfd31f4e33 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -1232,11 +1232,9 @@ object Aggregate { schema.forall(f => UnsafeRow.isMutable(f.dataType)) } - def supportsHashAggregate( - aggregateBufferAttributes: Seq[Attribute], groupingExpression: Seq[Expression]): Boolean = { + def supportsHashAggregate(aggregateBufferAttributes: Seq[Attribute]): Boolean = { val aggregationBufferSchema = DataTypeUtils.fromAttributes(aggregateBufferAttributes) - isAggregateBufferMutable(aggregationBufferSchema) && - groupingExpression.forall(e => UnsafeRowUtils.isBinaryStable(e.dataType)) + isAggregateBufferMutable(aggregationBufferSchema) } def supportsObjectHashAggregate(aggregateExpressions: Seq[AggregateExpression]): Boolean = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala index 367d4cfafb485..278c1fc3f73b1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala @@ -76,7 +76,7 @@ 
object AggUtils { resultExpressions: Seq[NamedExpression] = Nil, child: SparkPlan): SparkPlan = { val useHash = Aggregate.supportsHashAggregate( - aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes), groupingExpressions) + aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes)) val forceObjHashAggregate = forceApplyObjectHashAggregate(child.conf) val forceSortAggregate = forceApplySortAggregate(child.conf) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala index bdf17607d77c5..bf55e370cd83e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.execution.aggregate +import java.util.Locale import java.util.concurrent.TimeUnit._ import scala.collection.mutable @@ -33,17 +34,19 @@ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.plans.logical.Aggregate import org.apache.spark.sql.catalyst.types.DataTypeUtils import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes +import org.apache.spark.sql.catalyst.util.{truncatedString, CollationFactory} import org.apache.spark.sql.catalyst.util.DateTimeConstants.NANOS_PER_MILLIS -import org.apache.spark.sql.catalyst.util.truncatedString import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} import org.apache.spark.sql.execution.vectorized.MutableColumnarRow import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.types.{CalendarIntervalType, DecimalType, StringType} +import org.apache.spark.sql.types.{CalendarIntervalType, DecimalType, StringType, StructType} import org.apache.spark.unsafe.KVIterator +import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.ArrayImplicits._ import org.apache.spark.util.Utils + /** * Hash-based aggregate operator that can also fallback to sorting when data exceeds memory size. 
*/ @@ -59,7 +62,7 @@ case class HashAggregateExec( child: SparkPlan) extends AggregateCodegenSupport { - require(Aggregate.supportsHashAggregate(aggregateBufferAttributes, groupingExpressions)) + require(Aggregate.supportsHashAggregate(aggregateBufferAttributes)) override lazy val allAttributes: AttributeSeq = child.output ++ aggregateBufferAttributes ++ aggregateAttributes ++ @@ -87,6 +90,31 @@ case class HashAggregateExec( } } + private def collationAwareStringRows(row: InternalRow, schema: StructType): InternalRow = { + val newRow = new Array[Any](schema.length) + for (i <- schema.fields.indices) { + val field = schema.fields(i) + field.dataType match { + case st: StringType => + val str: UTF8String = row.getUTF8String(i) + val collationId: Int = st.collationId + if (collationId == CollationFactory.UTF8_BINARY_COLLATION_ID) { + newRow(i) = str + } else if (collationId == CollationFactory.UTF8_BINARY_LCASE_COLLATION_ID) { + newRow(i) = UTF8String.fromString(str.toString.toLowerCase(Locale.ROOT)) + } else { + val collator = CollationFactory.fetchCollation(collationId).collator + val collationKey = collator.getCollationKey(str.toString).toByteArray + newRow(i) = UTF8String.fromString(collationKey.map("%02x" format _).mkString) + } + case _ => + newRow(i) = row.get(i, field.dataType) + } + } + val project = UnsafeProjection.create(schema) + project(InternalRow.fromSeq(newRow.toIndexedSeq)) + } + protected override def doExecute(): RDD[InternalRow] = { val numOutputRows = longMetric("numOutputRows") val peakMemory = longMetric("peakMemory") @@ -96,9 +124,9 @@ case class HashAggregateExec( val numTasksFallBacked = longMetric("numTasksFallBacked") child.execute().mapPartitionsWithIndex { (partIndex, iter) => - +// val collationAwareIterator = iter.map(row => collationAwareStringRows(row, child.schema)) val beforeAgg = System.nanoTime() - val hasInput = iter.hasNext + val hasInput = iter.hasNext // collationAwareIterator.hasNext val res = if (!hasInput && groupingExpressions.nonEmpty) { // This is a grouped aggregate and the input iterator is empty, // so return an empty iterator. 
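The per-row conversion above relies on a property of collation sort keys: two strings that compare equal under a collation map to identical key bytes, so hashing the key is equivalent to hashing the string under that collation. A minimal sketch of that property, assuming only the CollationFactory API this patch already uses (collationNameToId, fetchCollation, and the ICU collator's getCollationKey) and an ICU-backed collation such as UNICODE_CI — binary collations take a separate fast path in the patch and never reach the collator:

import org.apache.spark.sql.catalyst.util.CollationFactory

object CollationKeySketch {
  // Returns the binary-stable sort key for `s` under the named collation.
  def sortKey(s: String, collationName: String): Seq[Byte] = {
    val collationId = CollationFactory.collationNameToId(collationName)
    val collator = CollationFactory.fetchCollation(collationId).collator
    collator.getCollationKey(s).toByteArray.toSeq
  }

  def main(args: Array[String]): Unit = {
    // "AA" and "aa" are equal under UNICODE_CI, so their keys coincide —
    // exactly what makes hash aggregation over the keys correct.
    assert(sortKey("AA", "UNICODE_CI") == sortKey("aa", "UNICODE_CI"))
    // Under case-sensitive UNICODE the keys differ, matching the byte-level
    // expectations in the CollationSuite test cases below.
    assert(sortKey("AA", "UNICODE") != sortKey("aa", "UNICODE"))
  }
}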
@@ -115,7 +143,7 @@ case class HashAggregateExec( (expressions, inputSchema) => MutableProjection.create(expressions, inputSchema), inputAttributes, - iter, + iter, // collationAwareIterator testFallbackStartsAt, numOutputRows, peakMemory, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala index 99c0dbfcb1448..fea96018b0e0a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala @@ -21,7 +21,12 @@ import scala.collection.immutable.Seq import scala.jdk.CollectionConverters.MapHasAsJava import org.apache.spark.SparkException -import org.apache.spark.sql.catalyst.ExtendedAnalysisException +import org.apache.spark.sql.catalyst.{ExtendedAnalysisException, InternalRow} +import org.apache.spark.sql.catalyst.analysis.{CollationKey, RewriteGroupByCollation} +import org.apache.spark.sql.catalyst.expressions.{BindReferences, UnsafeProjection} +// import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext +import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference} +import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LocalRelation} import org.apache.spark.sql.catalyst.util.CollationFactory import org.apache.spark.sql.connector.{DatasourceV2SQLBase, FakeV2ProviderWithCustomSchema} import org.apache.spark.sql.connector.catalog.{Identifier, InMemoryTable} @@ -33,6 +38,7 @@ import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAg import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, SortMergeJoinExec} import org.apache.spark.sql.internal.SqlApiConf import org.apache.spark.sql.types.{MapType, StringType, StructField, StructType} +import org.apache.spark.unsafe.types.UTF8String class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { protected val v2Source = classOf[FakeV2ProviderWithCustomSchema].getName @@ -369,20 +375,10 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { test("aggregates count respects collation") { Seq( - ("utf8_binary", Seq("AAA", "aaa"), Seq(Row(1, "AAA"), Row(1, "aaa"))), - ("utf8_binary", Seq("aaa", "aaa"), Seq(Row(2, "aaa"))), - ("utf8_binary", Seq("aaa", "bbb"), Seq(Row(1, "aaa"), Row(1, "bbb"))), - ("utf8_binary_lcase", Seq("aaa", "aaa"), Seq(Row(2, "aaa"))), - ("utf8_binary_lcase", Seq("AAA", "aaa"), Seq(Row(2, "AAA"))), - ("utf8_binary_lcase", Seq("aaa", "bbb"), Seq(Row(1, "aaa"), Row(1, "bbb"))), - ("unicode", Seq("AAA", "aaa"), Seq(Row(1, "AAA"), Row(1, "aaa"))), - ("unicode", Seq("aaa", "aaa"), Seq(Row(2, "aaa"))), - ("unicode", Seq("aaa", "bbb"), Seq(Row(1, "aaa"), Row(1, "bbb"))), - ("unicode_CI", Seq("aaa", "aaa"), Seq(Row(2, "aaa"))), - ("unicode_CI", Seq("AAA", "aaa"), Seq(Row(2, "AAA"))), - ("unicode_CI", Seq("aaa", "bbb"), Seq(Row(1, "aaa"), Row(1, "bbb"))) + ("unicode_ci", Seq("AA", "aa"), Seq(Row(2, "aa"))), ).foreach { case (collationName: String, input: Seq[String], expected: Seq[Row]) => + spark.conf.set("spark.sql.codegen.wholeStage", "false") checkAnswer(sql( s""" with t as ( @@ -395,7 +391,7 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { } } - test("hash agg is not used for non binary collations") { + test("hash agg is also used for non binary collations") { val tableNameNonBinary = "T_NON_BINARY" val tableNameBinary = "T_BINARY" withTable(tableNameNonBinary) { @@ -408,7 +404,7 @@ class CollationSuite 
extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { val dfNonBinary = sql(s"SELECT COUNT(*), c FROM $tableNameNonBinary GROUP BY c") assert(collectFirst(dfNonBinary.queryExecution.executedPlan) { case _: HashAggregateExec | _: ObjectHashAggregateExec => () - }.isEmpty) + }.nonEmpty) val dfBinary = sql(s"SELECT COUNT(*), c FROM $tableNameBinary GROUP BY c") assert(collectFirst(dfBinary.queryExecution.executedPlan) { @@ -819,4 +815,102 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { checkAnswer(dfNonBinary, dfBinary) } } + + test("CollationKey generates correct collation key") { + val testCases = Seq( + ("", "UTF8_BINARY", UTF8String.fromString("")), + ("aa", "UTF8_BINARY", UTF8String.fromString("aa")), + ("AA", "UTF8_BINARY", UTF8String.fromString("AA")), + ("aA", "UTF8_BINARY", UTF8String.fromString("aA")), + ("", "UTF8_BINARY_LCASE", UTF8String.fromString("")), + ("aa", "UTF8_BINARY_LCASE", UTF8String.fromString("aa")), + ("AA", "UTF8_BINARY_LCASE", UTF8String.fromString("aa")), + ("aA", "UTF8_BINARY_LCASE", UTF8String.fromString("aa")), + ("", "UNICODE", UTF8String.fromBytes(Array[Byte](1, 1, 0))), + ("aa", "UNICODE", UTF8String.fromBytes(Array[Byte](42, 42, 1, 6, 1, 6, 0))), + ("AA", "UNICODE", UTF8String.fromBytes(Array[Byte](42, 42, 1, 6, 1, -36, -36, 0))), + ("aA", "UNICODE", UTF8String.fromBytes(Array[Byte](42, 42, 1, 6, 1, -59, -36, 0))), + ("", "UNICODE_CI", UTF8String.fromBytes(Array[Byte](1, 0))), + ("aa", "UNICODE_CI", UTF8String.fromBytes(Array[Byte](42, 42, 1, 6, 0))), + ("AA", "UNICODE_CI", UTF8String.fromBytes(Array[Byte](42, 42, 1, 6, 0))), + ("aA", "UNICODE_CI", UTF8String.fromBytes(Array[Byte](42, 42, 1, 6, 0))) + ) + for ((input, collation, expected) <- testCases) { + val collationId: Int = CollationFactory.collationNameToId(collation) + val attrRef: AttributeReference = AttributeReference("attr", StringType(collationId))() + // generate CollationKey for the input string + val collationKey: CollationKey = CollationKey(attrRef) + val str: UTF8String = UTF8String.fromString(input) + assert(collationKey.nullSafeEval(str) === expected) + } + } + + test("CollationKey generates correct collation key using codegen") { + val testCases = Seq( + ("", "UTF8_BINARY", ""), + ("aa", "UTF8_BINARY", "6161"), + ("AA", "UTF8_BINARY", "4141"), + ("aA", "UTF8_BINARY", "4161"), + ("", "UTF8_BINARY_LCASE", ""), + ("aa", "UTF8_BINARY_LCASE", "6161"), + ("AA", "UTF8_BINARY_LCASE", "6161"), + ("aA", "UTF8_BINARY_LCASE", "6161"), + ("", "UNICODE", "101"), + ("aa", "UNICODE", "60106012a2a"), + ("AA", "UNICODE", "dcdc0106012a2a"), + ("aA", "UNICODE", "dcc50106012a2a"), + ("", "UNICODE_CI", "1"), + ("aa", "UNICODE_CI", "6012a2a"), + ("AA", "UNICODE_CI", "6012a2a"), + ("aA", "UNICODE_CI", "6012a2a"), + ) + for ((input, collation, expected) <- testCases) { + val collationId: Int = CollationFactory.collationNameToId(collation) + val attrRef: AttributeReference = AttributeReference("attr", StringType(collationId))() + // generate CollationKey for the input string + val collationKey: CollationKey = CollationKey(attrRef) + val str: UTF8String = UTF8String.fromString(input) + val boundExpr = BindReferences.bindReference(collationKey, Seq(attrRef)) + val ev = UnsafeProjection.create(Array(boundExpr).toIndexedSeq) + val strProj = ev.apply(InternalRow(str)) + assert(strProj.toString.split(',').last.startsWith(expected)) + } + } + + test("RewriteGroupByCollation rule rewrites Aggregate logical plan") { + val dataType = 
StringType(CollationFactory.collationNameToId("UNICODE_CI")) + val attrRef = AttributeReference("attr", dataType)() + val originalPlan = Aggregate(Seq(attrRef), Seq(attrRef), LocalRelation(attrRef)) + assert(originalPlan.groupingExpressions.size == 1) + assert(originalPlan.groupingExpressions.head == attrRef) + // plan level rewrite should put CollationKey in Aggregate logical plan + val newPlan = RewriteGroupByCollation(originalPlan) + val groupingExpressions = newPlan.asInstanceOf[Aggregate].groupingExpressions + assert(groupingExpressions.size == 1) // only 1 alias should be present in groupingExpressions + val groupingAlias = groupingExpressions.head.asInstanceOf[Alias] + assert(groupingAlias.child.isInstanceOf[CollationKey]) // alias should be a CollationKey + assert(groupingAlias.child.containsChild(attrRef)) // CollationKey should be for attrRef + } + + test("RewriteGroupByCollation rule works in SQL query analysis") { + spark.conf.set("spark.sql.codegen.wholeStage", value = false) + val dataType = StringType(CollationFactory.collationNameToId("UNICODE_CI")) + val schema = StructType(Seq(StructField("name", dataType))) + val data = Seq(Row("AA"), Row("aa"), Row("BB")) + val df = spark.createDataFrame(spark.sparkContext.parallelize(data), schema) + df.createOrReplaceTempView("tempTable") + val dfGroupBy = spark.sql("SELECT name, COUNT(*) FROM tempTable GROUP BY name") + // get the logical plan for the spark SQL query + val logicalPlan = dfGroupBy.queryExecution.analyzed + val newPlan = RewriteGroupByCollation(logicalPlan) + assert(newPlan.isInstanceOf[Aggregate]) + val groupingExpressions = newPlan.asInstanceOf[Aggregate].groupingExpressions + assert(groupingExpressions.size == 1) +// val groupingAlias = groupingExpressions.head.asInstanceOf[Alias] +// assert(groupingAlias.isInstanceOf[Alias]) +// assert(groupingAlias.child.isInstanceOf[CollationKey]) + // get the query execution result + checkAnswer(dfGroupBy, Seq(Row("AA", 2), Row("BB", 1))) + } + } From 76728ac64bd0010c0b5454f49c9349646b23269f Mon Sep 17 00:00:00 2001 From: Uros Bojanic <157381213+uros-db@users.noreply.github.com> Date: Tue, 9 Apr 2024 07:32:52 +0200 Subject: [PATCH 2/9] Preserve idempotence --- .../spark/sql/catalyst/analysis/RewriteGroupByCollation.scala | 4 ++-- .../org/apache/spark/sql/catalyst/optimizer/Optimizer.scala | 4 ++-- .../sql/catalyst/optimizer/PullOutGroupingExpressions.scala | 3 ++- .../src/test/scala/org/apache/spark/sql/CollationSuite.scala | 2 ++ 4 files changed, 8 insertions(+), 5 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteGroupByCollation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteGroupByCollation.scala index ce79ee9ae5753..333d9a7605086 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteGroupByCollation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteGroupByCollation.scala @@ -22,7 +22,7 @@ import java.util.Locale // import scala.collection.mutable import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference, ExpectsInputTypes, Expression, StringTypeAnyCollation, UnaryExpression} -import org.apache.spark.sql.catalyst.expressions.aggregate.AnyValue +import org.apache.spark.sql.catalyst.expressions.aggregate.First import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} // import org.apache.spark.sql.catalyst.expressions.codegen.Block.BlockHelper import 
org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan} @@ -50,7 +50,7 @@ object RewriteGroupByCollation extends Rule[LogicalPlan] { val newAggregateExpressions = a.aggregateExpressions.map { case attr: AttributeReference if aliasMap.contains(attr) => - Alias(AnyValue(attr, ignoreNulls = false), attr.name)() + Alias(First(attr, ignoreNulls = false), attr.name)() case other => other } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index abff866fb3ce1..2acf9cf9e4c71 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -131,8 +131,7 @@ abstract class Optimizer(catalogManager: CatalogManager) SimplifyExtractValueOps, OptimizeCsvJsonExprs, CombineConcats, - PushdownPredicatesAndPruneColumnsForCTEDef, - RewriteGroupByCollation) ++ + PushdownPredicatesAndPruneColumnsForCTEDef) ++ extendedOperatorOptimizationRules val operatorOptimizationBatch: Seq[Batch] = { @@ -295,6 +294,7 @@ abstract class Optimizer(catalogManager: CatalogManager) EliminateView, ReplaceExpressions, RewriteNonCorrelatedExists, + RewriteGroupByCollation, PullOutGroupingExpressions, ComputeCurrentTime, ReplaceCurrentLike(catalogManager), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PullOutGroupingExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PullOutGroupingExpressions.scala index ecc3619d584f5..13738963dd171 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PullOutGroupingExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PullOutGroupingExpressions.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.optimizer import scala.collection.mutable +import org.apache.spark.sql.catalyst.analysis.CollationKey import org.apache.spark.sql.catalyst.expressions.{Alias, Expression, NamedExpression} import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan, Project} @@ -51,7 +52,7 @@ object PullOutGroupingExpressions extends Rule[LogicalPlan] { case a: Aggregate if a.resolved => val complexGroupingExpressionMap = mutable.LinkedHashMap.empty[Expression, NamedExpression] val newGroupingExpressions = a.groupingExpressions.toIndexedSeq.map { - case e if !e.foldable && e.children.nonEmpty => + case e if !e.foldable && e.children.nonEmpty && !e.isInstanceOf[CollationKey] => complexGroupingExpressionMap .getOrElseUpdate(e.canonicalized, Alias(e, "_groupingexpression")()) .toAttribute diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala index fea96018b0e0a..474642e2b525a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala @@ -903,7 +903,9 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { // get the logical plan for the spark SQL query val logicalPlan = dfGroupBy.queryExecution.analyzed val newPlan = RewriteGroupByCollation(logicalPlan) + val newNewPlan = RewriteGroupByCollation(newPlan) assert(newPlan.isInstanceOf[Aggregate]) + assert(newNewPlan.isInstanceOf[Aggregate]) val groupingExpressions = 
newPlan.asInstanceOf[Aggregate].groupingExpressions assert(groupingExpressions.size == 1) // val groupingAlias = groupingExpressions.head.asInstanceOf[Alias] From b4733c351d40dbc6ca3e8fef18bd159e30a11e0a Mon Sep 17 00:00:00 2001 From: Uros Bojanic <157381213+uros-db@users.noreply.github.com> Date: Tue, 9 Apr 2024 16:52:30 +0200 Subject: [PATCH 3/9] Fix agg expressions --- .../spark/sql/catalyst/analysis/RewriteGroupByCollation.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteGroupByCollation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteGroupByCollation.scala index 333d9a7605086..7d6d5254f2735 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteGroupByCollation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteGroupByCollation.scala @@ -50,7 +50,7 @@ object RewriteGroupByCollation extends Rule[LogicalPlan] { val newAggregateExpressions = a.aggregateExpressions.map { case attr: AttributeReference if aliasMap.contains(attr) => - Alias(First(attr, ignoreNulls = false), attr.name)() + Alias(First(attr, ignoreNulls = false).toAggregateExpression(), attr.name)() case other => other } From 830f63df09d8b968cde2ff7b4717df0d60bcfed5 Mon Sep 17 00:00:00 2001 From: Uros Bojanic <157381213+uros-db@users.noreply.github.com> Date: Wed, 10 Apr 2024 07:22:11 +0200 Subject: [PATCH 4/9] Fix tests --- .../org/apache/spark/sql/CollationSuite.scala | 72 ++++++++----------- 1 file changed, 31 insertions(+), 41 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala index 474642e2b525a..83170930c4df4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala @@ -24,8 +24,7 @@ import org.apache.spark.SparkException import org.apache.spark.sql.catalyst.{ExtendedAnalysisException, InternalRow} import org.apache.spark.sql.catalyst.analysis.{CollationKey, RewriteGroupByCollation} import org.apache.spark.sql.catalyst.expressions.{BindReferences, UnsafeProjection} -// import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext -import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference} +import org.apache.spark.sql.catalyst.expressions.AttributeReference import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LocalRelation} import org.apache.spark.sql.catalyst.util.CollationFactory import org.apache.spark.sql.connector.{DatasourceV2SQLBase, FakeV2ProviderWithCustomSchema} @@ -375,7 +374,18 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { test("aggregates count respects collation") { Seq( - ("unicode_ci", Seq("AA", "aa"), Seq(Row(2, "aa"))), + ("utf8_binary", Seq("AAA", "aaa"), Seq(Row(1, "AAA"), Row(1, "aaa"))), + ("utf8_binary", Seq("aaa", "aaa"), Seq(Row(2, "aaa"))), + ("utf8_binary", Seq("aaa", "bbb"), Seq(Row(1, "aaa"), Row(1, "bbb"))), + ("utf8_binary_lcase", Seq("aaa", "aaa"), Seq(Row(2, "aaa"))), + ("utf8_binary_lcase", Seq("AAA", "aaa"), Seq(Row(2, "AAA"))), + ("utf8_binary_lcase", Seq("aaa", "bbb"), Seq(Row(1, "aaa"), Row(1, "bbb"))), + ("unicode", Seq("AAA", "aaa"), Seq(Row(1, "AAA"), Row(1, "aaa"))), + ("unicode", Seq("aaa", "aaa"), Seq(Row(2, "aaa"))), + ("unicode", Seq("aaa", "bbb"), Seq(Row(1, "aaa"), Row(1, "bbb"))), + ("unicode_CI", Seq("aaa", 
"aaa"), Seq(Row(2, "aaa"))), + ("unicode_CI", Seq("AAA", "aaa"), Seq(Row(2, "AAA"))), + ("unicode_CI", Seq("aaa", "bbb"), Seq(Row(1, "aaa"), Row(1, "bbb"))) ).foreach { case (collationName: String, input: Seq[String], expected: Seq[Row]) => spark.conf.set("spark.sql.codegen.wholeStage", "false") @@ -391,29 +401,6 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { } } - test("hash agg is also used for non binary collations") { - val tableNameNonBinary = "T_NON_BINARY" - val tableNameBinary = "T_BINARY" - withTable(tableNameNonBinary) { - withTable(tableNameBinary) { - sql(s"CREATE TABLE $tableNameNonBinary (c STRING COLLATE UTF8_BINARY_LCASE) USING PARQUET") - sql(s"INSERT INTO $tableNameNonBinary VALUES ('aaa')") - sql(s"CREATE TABLE $tableNameBinary (c STRING COLLATE UTF8_BINARY) USING PARQUET") - sql(s"INSERT INTO $tableNameBinary VALUES ('aaa')") - - val dfNonBinary = sql(s"SELECT COUNT(*), c FROM $tableNameNonBinary GROUP BY c") - assert(collectFirst(dfNonBinary.queryExecution.executedPlan) { - case _: HashAggregateExec | _: ObjectHashAggregateExec => () - }.nonEmpty) - - val dfBinary = sql(s"SELECT COUNT(*), c FROM $tableNameBinary GROUP BY c") - assert(collectFirst(dfBinary.queryExecution.executedPlan) { - case _: HashAggregateExec | _: ObjectHashAggregateExec => () - }.nonEmpty) - } - } - } - test("text writing to parquet with collation enclosed with backticks") { withTempPath{ path => sql(s"select 'a' COLLATE `UNICODE`").write.parquet(path.getAbsolutePath) @@ -880,16 +867,16 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { test("RewriteGroupByCollation rule rewrites Aggregate logical plan") { val dataType = StringType(CollationFactory.collationNameToId("UNICODE_CI")) val attrRef = AttributeReference("attr", dataType)() + // original logical plan should only contain one attribute in groupingExpressions val originalPlan = Aggregate(Seq(attrRef), Seq(attrRef), LocalRelation(attrRef)) assert(originalPlan.groupingExpressions.size == 1) + assert(originalPlan.groupingExpressions.head.isInstanceOf[AttributeReference]) assert(originalPlan.groupingExpressions.head == attrRef) - // plan level rewrite should put CollationKey in Aggregate logical plan - val newPlan = RewriteGroupByCollation(originalPlan) - val groupingExpressions = newPlan.asInstanceOf[Aggregate].groupingExpressions - assert(groupingExpressions.size == 1) // only 1 alias should be present in groupingExpressions - val groupingAlias = groupingExpressions.head.asInstanceOf[Alias] - assert(groupingAlias.child.isInstanceOf[CollationKey]) // alias should be a CollationKey - assert(groupingAlias.child.containsChild(attrRef)) // CollationKey should be for attrRef + // plan level rewrite should replace attr with CollationKey in Aggregate logical plan + val newPlan = RewriteGroupByCollation(originalPlan).asInstanceOf[Aggregate] + assert(newPlan.groupingExpressions.size == 1) + assert(newPlan.groupingExpressions.head.isInstanceOf[CollationKey]) + assert(newPlan.groupingExpressions.head.asInstanceOf[CollationKey].expr == attrRef) } test("RewriteGroupByCollation rule works in SQL query analysis") { @@ -900,19 +887,22 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { val df = spark.createDataFrame(spark.sparkContext.parallelize(data), schema) df.createOrReplaceTempView("tempTable") val dfGroupBy = spark.sql("SELECT name, COUNT(*) FROM tempTable GROUP BY name") - // get the logical plan for the spark SQL query + // test RewriteGroupByCollation 
idempotence val logicalPlan = dfGroupBy.queryExecution.analyzed val newPlan = RewriteGroupByCollation(logicalPlan) val newNewPlan = RewriteGroupByCollation(newPlan) - assert(newPlan.isInstanceOf[Aggregate]) - assert(newNewPlan.isInstanceOf[Aggregate]) - val groupingExpressions = newPlan.asInstanceOf[Aggregate].groupingExpressions - assert(groupingExpressions.size == 1) -// val groupingAlias = groupingExpressions.head.asInstanceOf[Alias] -// assert(groupingAlias.isInstanceOf[Alias]) -// assert(groupingAlias.child.isInstanceOf[CollationKey]) + assert(newPlan == newNewPlan) // get the query execution result checkAnswer(dfGroupBy, Seq(Row("AA", 2), Row("BB", 1))) } + test("RewriteGroupByCollation doesn't disrupt aggregation on complex types") { + val table = "table_agg" + withTable(table) { + sql(s"create table $table (a array) using parquet") + sql(s"insert into $table values (array('aaa')), (array('AAA'))") + checkAnswer(sql(s"select distinct a from $table"), Seq(Row(Seq("aaa")))) + } + } + } From c0f93963f82833dc5cff55452fc327a47d861971 Mon Sep 17 00:00:00 2001 From: Uros Bojanic <157381213+uros-db@users.noreply.github.com> Date: Wed, 10 Apr 2024 09:19:34 +0200 Subject: [PATCH 5/9] scalastyle fix --- .../spark/sql/catalyst/analysis/RewriteGroupByCollation.scala | 3 --- .../src/test/scala/org/apache/spark/sql/CollationSuite.scala | 2 +- 2 files changed, 1 insertion(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteGroupByCollation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteGroupByCollation.scala index 7d6d5254f2735..6138bc14c322c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteGroupByCollation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteGroupByCollation.scala @@ -19,12 +19,9 @@ package org.apache.spark.sql.catalyst.analysis import java.util.Locale -// import scala.collection.mutable - import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference, ExpectsInputTypes, Expression, StringTypeAnyCollation, UnaryExpression} import org.apache.spark.sql.catalyst.expressions.aggregate.First import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} -// import org.apache.spark.sql.catalyst.expressions.codegen.Block.BlockHelper import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan} import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.catalyst.util.CollationFactory diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala index 83170930c4df4..3f562aada2b57 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala @@ -849,7 +849,7 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { ("", "UNICODE_CI", "1"), ("aa", "UNICODE_CI", "6012a2a"), ("AA", "UNICODE_CI", "6012a2a"), - ("aA", "UNICODE_CI", "6012a2a"), + ("aA", "UNICODE_CI", "6012a2a") ) for ((input, collation, expected) <- testCases) { val collationId: Int = CollationFactory.collationNameToId(collation) From dc683d7e1c061ffdb4c17f45861e14a6d3b10871 Mon Sep 17 00:00:00 2001 From: Uros Bojanic <157381213+uros-db@users.noreply.github.com> Date: Wed, 10 Apr 2024 14:30:53 +0200 Subject: [PATCH 6/9] Remove unused imports --- .../src/test/scala/org/apache/spark/sql/CollationSuite.scala | 1 - 1 
file changed, 1 deletion(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala index 3f562aada2b57..61d84f0637c5c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala @@ -33,7 +33,6 @@ import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.CatalogHelper import org.apache.spark.sql.connector.catalog.CatalogV2Util.withDefaultOwnership import org.apache.spark.sql.errors.DataTypeErrors.toSQLType import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper -import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec} import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, SortMergeJoinExec} import org.apache.spark.sql.internal.SqlApiConf import org.apache.spark.sql.types.{MapType, StringType, StructField, StructType} From c8855a571028bc21714da6b9a0c6fe10c45d249e Mon Sep 17 00:00:00 2001 From: Uros Bojanic <157381213+uros-db@users.noreply.github.com> Date: Mon, 15 Apr 2024 06:56:55 +0200 Subject: [PATCH 7/9] Support array (in progress) --- .../sql/catalyst/util/CollationFactory.java | 11 ++ .../analysis/RewriteGroupByCollation.scala | 81 +++++++++++--- .../expressions/codegen/CodeGenerator.scala | 3 +- .../sql/catalyst/optimizer/Optimizer.scala | 3 +- .../aggregate/HashAggregateExec.scala | 2 +- .../org/apache/spark/sql/CollationSuite.scala | 101 +++++++++++++++++- 6 files changed, 177 insertions(+), 24 deletions(-) diff --git a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java index 119508a37e717..a4c5ea6b865db 100644 --- a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java +++ b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java @@ -214,4 +214,15 @@ public static Collation fetchCollation(String collationName) throws SparkExcepti int collationId = collationNameToId(collationName); return collationTable[collationId]; } + + public static UTF8String getCollationKeyLcase(UTF8String str) { + return str.toLowerCase(); + } + + public static UTF8String[] getCollationKeyLcase(UTF8String[] arr) { + for (int i = 0; i < arr.length; i++) { + arr[i] = arr[i].toLowerCase(); + } + return arr; + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteGroupByCollation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteGroupByCollation.scala index 6138bc14c322c..c955994cae158 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteGroupByCollation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteGroupByCollation.scala @@ -24,8 +24,8 @@ import org.apache.spark.sql.catalyst.expressions.aggregate.First import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan} import org.apache.spark.sql.catalyst.rules.Rule -import org.apache.spark.sql.catalyst.util.CollationFactory -import org.apache.spark.sql.types.{AbstractDataType, DataType, StringType} +import org.apache.spark.sql.catalyst.util.{CollationFactory, GenericArrayData} +import org.apache.spark.sql.types.{AbstractDataType, ArrayType, DataType, StringType} import 
org.apache.spark.unsafe.types.UTF8String /** @@ -36,7 +36,10 @@ object RewriteGroupByCollation extends Rule[LogicalPlan] { case a: Aggregate => val aliasMap = a.groupingExpressions.collect { case attr: AttributeReference if attr.dataType.isInstanceOf[StringType] => - attr -> CollationKey(attr) // Alias(CollationKey(attr), attr.name)() + attr -> CollationKey(attr) + case attr: AttributeReference if attr.dataType.isInstanceOf[ArrayType] + && attr.dataType.asInstanceOf[ArrayType].elementType.isInstanceOf[StringType] => + attr -> CollationKey(attr) }.toMap val newGroupingExpressions = a.groupingExpressions.map { @@ -61,13 +64,37 @@ object RewriteGroupByCollation extends Rule[LogicalPlan] { } case class CollationKey(expr: Expression) extends UnaryExpression with ExpectsInputTypes { - override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeAnyCollation) + override def inputTypes: Seq[AbstractDataType] = + Seq(StringTypeAnyCollation, ArrayType(StringType)) override def dataType: DataType = expr.dataType - final lazy val collationId: Int = dataType.asInstanceOf[StringType].collationId + final lazy val collationId: Int = dataType match { + case _: StringType => + dataType.asInstanceOf[StringType].collationId + case ArrayType(_: StringType, _) => + val arr = dataType.asInstanceOf[ArrayType] + arr.elementType.asInstanceOf[StringType].collationId + } - override def nullSafeEval(input: Any): Any = { - val str: UTF8String = input.asInstanceOf[UTF8String] + override def nullSafeEval(input: Any): Any = dataType match { + case _: StringType => + getCollationKey(input.asInstanceOf[UTF8String]) + case ArrayType(_: StringType, _) => + input match { + case arr: Array[UTF8String] => + arr.map(getCollationKey) + case arr: GenericArrayData => + val result = new Array[UTF8String](arr.numElements()) + for (i <- 0 until arr.numElements()) { + result(i) = getCollationKey(arr.getUTF8String(i)) + } + new GenericArrayData(result) + case _ => + None + } + } + + def getCollationKey(str: UTF8String): UTF8String = { if (collationId == CollationFactory.UTF8_BINARY_COLLATION_ID) { str } else if (collationId == CollationFactory.UTF8_BINARY_LCASE_COLLATION_ID) { @@ -79,15 +106,37 @@ case class CollationKey(expr: Expression) extends UnaryExpression with ExpectsIn } } - override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - if (collationId == CollationFactory.UTF8_BINARY_COLLATION_ID) { - defineCodeGen(ctx, ev, c => s"$c") - } else if (collationId == CollationFactory.UTF8_BINARY_LCASE_COLLATION_ID) { - defineCodeGen(ctx, ev, c => s"$c.toLowerCase()") - } else { - defineCodeGen(ctx, ev, c => s"UTF8String.fromBytes(CollationFactory.fetchCollation" + - s"($collationId).collator.getCollationKey($c.toString()).toByteArray())") - } + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = dataType match { + case _: StringType => + if (collationId == CollationFactory.UTF8_BINARY_COLLATION_ID) { + defineCodeGen(ctx, ev, c => s"$c") + } else if (collationId == CollationFactory.UTF8_BINARY_LCASE_COLLATION_ID) { + defineCodeGen(ctx, ev, c => s"$c.toLowerCase()") + } else { + defineCodeGen(ctx, ev, c => s"UTF8String.fromBytes(CollationFactory.fetchCollation" + + s"($collationId).collator.getCollationKey($c.toString()).toByteArray())") + } + case ArrayType(_: StringType, _) => + val expr = ctx.addReferenceObj("this", this) + val arrData = ctx.freshName("arrData") + val arrLength = ctx.freshName("arrLength") + val arrResult = ctx.freshName("arrResult") + val arrIndex = 
ctx.freshName("arrIndex") + nullSafeCodeGen(ctx, ev, eval => { + s""" + |if ($eval instanceof GenericArrayData) { + | ArrayData $arrData = (ArrayData)$eval; + | int $arrLength = $arrData.numElements(); + | UTF8String[] $arrResult = new UTF8String[$arrLength]; + | for (int $arrIndex = 0; $arrIndex < $arrLength; $arrIndex++) { + | $arrResult[$arrIndex] = $expr.getCollationKey($arrData.getUTF8String($arrIndex)); + | } + | ${ev.value} = new GenericArrayData($arrResult); + |} else { + | ${ev.value} = null; + |} + """.stripMargin + }) } override protected def withNewChildInternal(newChild: Expression): Expression = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index dfe07a443a230..d03997f7f874d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -39,7 +39,7 @@ import org.apache.spark.sql.catalyst.encoders.HashableWeakReference import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.types._ -import org.apache.spark.sql.catalyst.util.{ArrayData, CollationFactory, MapData, SQLOrderingUtil, UnsafeRowUtils} +import org.apache.spark.sql.catalyst.util.{ArrayData, CollationFactory, GenericArrayData, MapData, SQLOrderingUtil, UnsafeRowUtils} import org.apache.spark.sql.catalyst.util.DateTimeConstants.NANOS_PER_MILLIS import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.internal.SQLConf @@ -1521,6 +1521,7 @@ object CodeGenerator extends Logging { classOf[CalendarInterval].getName, classOf[VariantVal].getName, classOf[ArrayData].getName, + classOf[GenericArrayData].getName, classOf[UnsafeArrayData].getName, classOf[MapData].getName, classOf[UnsafeMapData].getName, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 2acf9cf9e4c71..2c759dee52125 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -2223,7 +2223,8 @@ object ConvertToLocalRelation extends Rule[LogicalPlan] { object ReplaceDistinctWithAggregate extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan.transformWithPruning( _.containsPattern(DISTINCT_LIKE), ruleId) { - case Distinct(child) => Aggregate(child.output, child.output, child) + case Distinct(child) => + Aggregate(child.output, child.output, child) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala index bf55e370cd83e..b29b644354ddf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala @@ -98,7 +98,7 @@ case class HashAggregateExec( case st: StringType => val str: UTF8String = row.getUTF8String(i) val collationId: Int = st.collationId - if (collationId == CollationFactory.UTF8_BINARY_COLLATION_ID) { + if 
(CollationFactory.fetchCollation(collationId).supportsBinaryEquality) { newRow(i) = str } else if (collationId == CollationFactory.UTF8_BINARY_LCASE_COLLATION_ID) { newRow(i) = UTF8String.fromString(str.toString.toLowerCase(Locale.ROOT)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala index 61d84f0637c5c..8610a2c136b78 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala @@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.analysis.{CollationKey, RewriteGroupByColla import org.apache.spark.sql.catalyst.expressions.{BindReferences, UnsafeProjection} import org.apache.spark.sql.catalyst.expressions.AttributeReference import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LocalRelation} -import org.apache.spark.sql.catalyst.util.CollationFactory +import org.apache.spark.sql.catalyst.util.{ArrayData, CollationFactory} import org.apache.spark.sql.connector.{DatasourceV2SQLBase, FakeV2ProviderWithCustomSchema} import org.apache.spark.sql.connector.catalog.{Identifier, InMemoryTable} import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.CatalogHelper @@ -35,7 +35,7 @@ import org.apache.spark.sql.errors.DataTypeErrors.toSQLType import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, SortMergeJoinExec} import org.apache.spark.sql.internal.SqlApiConf -import org.apache.spark.sql.types.{MapType, StringType, StructField, StructType} +import org.apache.spark.sql.types.{ArrayType, MapType, StringType, StructField, StructType} import org.apache.spark.unsafe.types.UTF8String class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { @@ -802,7 +802,7 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { } } - test("CollationKey generates correct collation key") { + test("CollationKey generates correct collation key for string") { val testCases = Seq( ("", "UTF8_BINARY", UTF8String.fromString("")), ("aa", "UTF8_BINARY", UTF8String.fromString("aa")), @@ -831,7 +831,37 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { } } - test("CollationKey generates correct collation key using codegen") { + test("CollationKey generates correct collation key for array") { + val testCases = Seq( + ("", "UTF8_BINARY", Array(UTF8String.fromString(""))), + ("aa", "UTF8_BINARY", Array(UTF8String.fromString("aa"))), + ("AA", "UTF8_BINARY", Array(UTF8String.fromString("AA"))), + ("aA", "UTF8_BINARY", Array(UTF8String.fromString("aA"))), + ("", "UTF8_BINARY_LCASE", Array(UTF8String.fromString(""))), + ("aa", "UTF8_BINARY_LCASE", Array(UTF8String.fromString("aa"))), + ("AA", "UTF8_BINARY_LCASE", Array(UTF8String.fromString("aa"))), + ("aA", "UTF8_BINARY_LCASE", Array(UTF8String.fromString("aa"))), + ("", "UNICODE", Array(UTF8String.fromBytes(Array[Byte](1, 1, 0)))), + ("aa", "UNICODE", Array(UTF8String.fromBytes(Array[Byte](42, 42, 1, 6, 1, 6, 0)))), + ("AA", "UNICODE", Array(UTF8String.fromBytes(Array[Byte](42, 42, 1, 6, 1, -36, -36, 0)))), + ("aA", "UNICODE", Array(UTF8String.fromBytes(Array[Byte](42, 42, 1, 6, 1, -59, -36, 0)))), + ("", "UNICODE_CI", Array(UTF8String.fromBytes(Array[Byte](1, 0)))), + ("aa", "UNICODE_CI", Array(UTF8String.fromBytes(Array[Byte](42, 42, 1, 6, 0)))), + ("AA", "UNICODE_CI", Array(UTF8String.fromBytes(Array[Byte](42, 42, 1, 
6, 0)))), + ("aA", "UNICODE_CI", Array(UTF8String.fromBytes(Array[Byte](42, 42, 1, 6, 0)))) + ) + for ((input, collation, expected) <- testCases) { + val collationId: Int = CollationFactory.collationNameToId(collation) + val array = ArrayType(StringType(collationId)) + val attrRef: AttributeReference = AttributeReference("attr", array)() + // generate CollationKey for the input string + val collationKey: CollationKey = CollationKey(attrRef) + val arr: Array[UTF8String] = Array(UTF8String.fromString(input)) + assert(collationKey.nullSafeEval(arr) === expected) + } + } + + test("CollationKey generates correct collation key for string using codegen") { val testCases = Seq( ("", "UTF8_BINARY", ""), ("aa", "UTF8_BINARY", "6161"), @@ -863,6 +893,40 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { } } + test("CollationKey generates correct collation key for array using codegen") { + val testCases = Seq( + (Seq(""), "UTF8_BINARY", ""), + (Seq("aa"), "UTF8_BINARY", "6161"), + (Seq("AA"), "UTF8_BINARY", "4141"), + (Seq("aA"), "UTF8_BINARY", "4161"), + (Seq(""), "UTF8_BINARY_LCASE", ""), + (Seq("aa"), "UTF8_BINARY_LCASE", "6161"), + (Seq("AA"), "UTF8_BINARY_LCASE", "6161"), + (Seq("aA"), "UTF8_BINARY_LCASE", "6161"), + (Seq(""), "UNICODE", "101"), + (Seq("aa"), "UNICODE", "60106012a2a"), + (Seq("AA"), "UNICODE", "dcdc0106012a2a"), + (Seq("aA"), "UNICODE", "dcc50106012a2a"), + (Seq(""), "UNICODE_CI", "1"), + (Seq("aa"), "UNICODE_CI", "6012a2a"), + (Seq("AA"), "UNICODE_CI", "6012a2a"), + (Seq("aA"), "UNICODE_CI", "6012a2a") + ) + for ((input, collation, expected) <- testCases) { + val collationId: Int = CollationFactory.collationNameToId(collation) + val array = ArrayType(StringType(collationId)) + val attrRef: AttributeReference = AttributeReference("attr", array)() + // generate CollationKey for the input string + val collationKey: CollationKey = CollationKey(attrRef) + val arr: Seq[UTF8String] = input.map(UTF8String.fromString) + val arrData: ArrayData = ArrayData.toArrayData(arr) + val boundExpr = BindReferences.bindReference(collationKey, Seq(attrRef)) + val ev = UnsafeProjection.create(Array(boundExpr).toIndexedSeq) + val strProj = ev.apply(InternalRow(arrData)) + assert(strProj.toString.split(',').last.startsWith(expected)) + } + } + test("RewriteGroupByCollation rule rewrites Aggregate logical plan") { val dataType = StringType(CollationFactory.collationNameToId("UNICODE_CI")) val attrRef = AttributeReference("attr", dataType)() @@ -895,7 +959,34 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { checkAnswer(dfGroupBy, Seq(Row("AA", 2), Row("BB", 1))) } - test("RewriteGroupByCollation doesn't disrupt aggregation on complex types") { + test("RewriteGroupByCollation rule works in SQL query analysis with array type") { + spark.conf.set("spark.sql.codegen.wholeStage", value = false) + val collationId = CollationFactory.collationNameToId("UNICODE_CI") + val dataType = ArrayType(StringType(collationId)) + val schema = StructType(Seq(StructField("name", dataType))) + val data = Seq(Row(Seq("AA")), Row(Seq("aa")), Row(Seq("BB"))) + val df = spark.createDataFrame(spark.sparkContext.parallelize(data), schema) + df.createOrReplaceTempView("tempTable") + val dfGroupBy = spark.sql("SELECT name, COUNT(*) FROM tempTable GROUP BY name") + // test RewriteGroupByCollation idempotence + val logicalPlan = dfGroupBy.queryExecution.analyzed + val newPlan = RewriteGroupByCollation(logicalPlan) + val newNewPlan = RewriteGroupByCollation(newPlan) + 
assert(newPlan == newNewPlan) + // get the query execution result + checkAnswer(dfGroupBy, Seq(Row(Seq("AA"), 2), Row(Seq("BB"), 1))) + } + + test("Hash aggregation works on string type") { + val table = "table_agg" + withTable(table) { + sql(s"create table $table (a string collate utf8_binary_lcase) using parquet") + sql(s"insert into $table values ('aaa'), ('AAA')") + checkAnswer(sql(s"select distinct a from $table"), Seq(Row("aaa"))) + } + } + + test("Hash aggregation works on array type") { val table = "table_agg" withTable(table) { sql(s"create table $table (a array) using parquet") From 05c03787ace66b5afbe59736ea73ee049843299c Mon Sep 17 00:00:00 2001 From: Uros Bojanic <157381213+uros-db@users.noreply.github.com> Date: Mon, 15 Apr 2024 21:30:04 +0200 Subject: [PATCH 8/9] Fix rewrite for array --- .../sql/catalyst/analysis/RewriteGroupByCollation.scala | 9 +++++---- .../test/scala/org/apache/spark/sql/CollationSuite.scala | 7 +++++-- 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteGroupByCollation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteGroupByCollation.scala index c955994cae158..b173731f604b9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteGroupByCollation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteGroupByCollation.scala @@ -19,13 +19,14 @@ package org.apache.spark.sql.catalyst.analysis import java.util.Locale -import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference, ExpectsInputTypes, Expression, StringTypeAnyCollation, UnaryExpression} +import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference, ExpectsInputTypes, Expression, UnaryExpression} import org.apache.spark.sql.catalyst.expressions.aggregate.First import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan} import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.catalyst.util.{CollationFactory, GenericArrayData} -import org.apache.spark.sql.types.{AbstractDataType, ArrayType, DataType, StringType} +import org.apache.spark.sql.internal.types.{AbstractArrayType, StringTypeAnyCollation} +import org.apache.spark.sql.types.{AbstractDataType, ArrayType, DataType, StringType, TypeCollection} import org.apache.spark.unsafe.types.UTF8String /** @@ -65,7 +66,7 @@ object RewriteGroupByCollation extends Rule[LogicalPlan] { case class CollationKey(expr: Expression) extends UnaryExpression with ExpectsInputTypes { override def inputTypes: Seq[AbstractDataType] = - Seq(StringTypeAnyCollation, ArrayType(StringType)) + Seq(TypeCollection(StringTypeAnyCollation, AbstractArrayType(StringTypeAnyCollation))) override def dataType: DataType = expr.dataType final lazy val collationId: Int = dataType match { @@ -124,7 +125,7 @@ case class CollationKey(expr: Expression) extends UnaryExpression with ExpectsIn val arrIndex = ctx.freshName("arrIndex") nullSafeCodeGen(ctx, ev, eval => { s""" - |if ($eval instanceof GenericArrayData) { + |if ($eval instanceof ArrayData) { | ArrayData $arrData = (ArrayData)$eval; | int $arrLength = $arrData.numElements(); | UTF8String[] $arrResult = new UTF8String[$arrLength]; diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala index becdab8fb4ffd..187ca31ff7d4f 
100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala @@ -988,6 +988,7 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { val attrRef: AttributeReference = AttributeReference("attr", StringType(collationId))() // generate CollationKey for the input string val collationKey: CollationKey = CollationKey(attrRef) + assert(collationKey.resolved) val str: UTF8String = UTF8String.fromString(input) assert(collationKey.nullSafeEval(str) === expected) } @@ -1018,6 +1019,7 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { val attrRef: AttributeReference = AttributeReference("attr", array)() // generate CollationKey for the input array val collationKey: CollationKey = CollationKey(attrRef) + assert(collationKey.resolved) val arr: Array[UTF8String] = Array(UTF8String.fromString(input)) assert(collationKey.nullSafeEval(arr) === expected) } @@ -1099,13 +1101,14 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { assert(originalPlan.groupingExpressions.head == attrRef) // plan level rewrite should replace attr with CollationKey in Aggregate logical plan val newPlan = RewriteGroupByCollation(originalPlan).asInstanceOf[Aggregate] + assert(newPlan.resolved) assert(newPlan.groupingExpressions.size == 1) assert(newPlan.groupingExpressions.head.isInstanceOf[CollationKey]) assert(newPlan.groupingExpressions.head.asInstanceOf[CollationKey].expr == attrRef) } test("RewriteGroupByCollation rule works in SQL query analysis") { - spark.conf.set("spark.sql.codegen.wholeStage", value = false) +// spark.conf.set("spark.sql.codegen.wholeStage", value = false) val dataType = StringType(CollationFactory.collationNameToId("UNICODE_CI")) val schema = StructType(Seq(StructField("name", dataType))) val data = Seq(Row("AA"), Row("aa"), Row("BB")) @@ -1122,7 +1125,7 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { } test("RewriteGroupByCollation rule works in SQL query analysis with array type") { - spark.conf.set("spark.sql.codegen.wholeStage", value = false) +// spark.conf.set("spark.sql.codegen.wholeStage", value = false) val collationId = CollationFactory.collationNameToId("UNICODE_CI") val dataType = ArrayType(StringType(collationId)) val schema = StructType(Seq(StructField("name", dataType))) From ec447e308d756303c90bfdddbd15cd0341de7831 Mon Sep 17 00:00:00 2001 From: Uros Bojanic <157381213+uros-db@users.noreply.github.com> Date: Wed, 17 Apr 2024 10:41:16 +0200 Subject: [PATCH 9/9] GROUP BY / DISTINCT tests --- .../org/apache/spark/sql/CollationSuite.scala | 26 ++++++++++++++++--- 1 file changed, 22 insertions(+), 4 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala index 187ca31ff7d4f..61fe7333be8d9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala @@ -1142,21 +1142,39 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { checkAnswer(dfGroupBy, Seq(Row(Seq("AA"), 2), Row(Seq("BB"), 1))) } - test("Hash aggregation works on string type") { + test("DISTINCT works on string type") { val table = "table_agg" withTable(table) { sql(s"create table $table (a string collate utf8_binary_lcase) using parquet") sql(s"insert into $table values ('aaa'), ('AAA')") - 
checkAnswer(sql(s"select distinct a from $table"), Seq(Row("aaa"))) + checkAnswer(sql(s"select distinct a from $table"), Seq(Row("axa"))) } } - test("Hash aggregation works on array type") { + test("DISTINCT works on array type") { val table = "table_agg" withTable(table) { sql(s"create table $table (a array) using parquet") sql(s"insert into $table values (array('aaa')), (array('AAA'))") - checkAnswer(sql(s"select distinct a from $table"), Seq(Row(Seq("aaa")))) + checkAnswer(sql(s"select distinct a from $table"), Seq(Row(Seq("axa")))) + } + } + + test("GROUP BY works on string type") { + val table = "table_agg" + withTable(table) { + sql(s"create table $table (a string collate utf8_binary_lcase) using parquet") + sql(s"insert into $table values ('aaa'), ('AAA')") + checkAnswer(sql(s"select a, count(*) from $table group by a"), Seq(Row("axa", 2))) + } + } + + test("GROUP BY works on array type") { + val table = "table_agg" + withTable(table) { + sql(s"create table $table (a array) using parquet") + sql(s"insert into $table values (array('aaa')), (array('AAA'))") + checkAnswer(sql(s"select a, count(*) from $table group by a"), Seq(Row(Seq("axa"), 2))) } }