diff --git a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java
index 9786c559da44b..726bc20b448f9 100644
--- a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java
+++ b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java
@@ -259,4 +259,17 @@ public static Collation fetchCollation(String collationName) throws SparkExcepti
     int collationId = collationNameToId(collationName);
     return collationTable[collationId];
   }
+
+  public static UTF8String getCollationKeyLcase(UTF8String str) {
+    return str.toLowerCase();
+  }
+
+  public static UTF8String[] getCollationKeyLcase(UTF8String[] arr) {
+    // Copy into a fresh array instead of mutating the caller's input in place.
+    UTF8String[] result = new UTF8String[arr.length];
+    for (int i = 0; i < arr.length; i++) {
+      result[i] = arr[i].toLowerCase();
+    }
+    return result;
+  }
 }
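The two getCollationKeyLcase overloads reduce UTF8_BINARY_LCASE comparison to plain binary
comparison by normalizing values to their lowercase form. A minimal sketch of the intended
usage (Scala, assuming the catalyst and unsafe modules are on the classpath):

    import org.apache.spark.sql.catalyst.util.CollationFactory
    import org.apache.spark.unsafe.types.UTF8String

    // "aa", "AA" and "aA" all normalize to the same key, so grouping on the
    // key is equivalent to grouping under UTF8_BINARY_LCASE semantics.
    val key = CollationFactory.getCollationKeyLcase(UTF8String.fromString("aA"))
    assert(key == UTF8String.fromString("aa"))

    val keys = CollationFactory.getCollationKeyLcase(
      Array(UTF8String.fromString("aA"), UTF8String.fromString("BB")))
    assert(keys.map(_.toString).sameElements(Array("aa", "bb")))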
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteGroupByCollation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteGroupByCollation.scala
new file mode 100644
index 0000000000000..b173731f604b9
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteGroupByCollation.scala
@@ -0,0 +1,148 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.analysis
+
+import java.util.Locale
+
+import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference, ExpectsInputTypes, Expression, UnaryExpression}
+import org.apache.spark.sql.catalyst.expressions.aggregate.First
+import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
+import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan}
+import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.catalyst.util.{CollationFactory, GenericArrayData}
+import org.apache.spark.sql.internal.types.{AbstractArrayType, StringTypeAnyCollation}
+import org.apache.spark.sql.types.{AbstractDataType, ArrayType, DataType, StringType, TypeCollection}
+import org.apache.spark.unsafe.types.UTF8String
+
+/**
+ * A rule that rewrites aggregations over collated strings to group by collation keys instead.
+ */
+object RewriteGroupByCollation extends Rule[LogicalPlan] {
+  def apply(plan: LogicalPlan): LogicalPlan = plan transformUpWithNewOutput {
+    case a: Aggregate =>
+      val aliasMap = a.groupingExpressions.collect {
+        case attr: AttributeReference if attr.dataType.isInstanceOf[StringType] =>
+          attr -> CollationKey(attr)
+        case attr: AttributeReference if attr.dataType.isInstanceOf[ArrayType]
+          && attr.dataType.asInstanceOf[ArrayType].elementType.isInstanceOf[StringType] =>
+          attr -> CollationKey(attr)
+      }.toMap
+
+      val newGroupingExpressions = a.groupingExpressions.map {
+        case attr: AttributeReference if aliasMap.contains(attr) =>
+          aliasMap(attr)
+        case other => other
+      }
+
+      val newAggregateExpressions = a.aggregateExpressions.map {
+        case attr: AttributeReference if aliasMap.contains(attr) =>
+          Alias(First(attr, ignoreNulls = false).toAggregateExpression(), attr.name)()
+        case other => other
+      }
+
+      val newAggregate = a.copy(
+        groupingExpressions = newGroupingExpressions,
+        aggregateExpressions = newAggregateExpressions
+      )
+
+      (newAggregate, a.output.zip(newAggregate.output))
+  }
+}
+
+case class CollationKey(expr: Expression) extends UnaryExpression with ExpectsInputTypes {
+  override def inputTypes: Seq[AbstractDataType] =
+    Seq(TypeCollection(StringTypeAnyCollation, AbstractArrayType(StringTypeAnyCollation)))
+  override def dataType: DataType = expr.dataType
+
+  final lazy val collationId: Int = dataType match {
+    case _: StringType =>
+      dataType.asInstanceOf[StringType].collationId
+    case ArrayType(_: StringType, _) =>
+      val arr = dataType.asInstanceOf[ArrayType]
+      arr.elementType.asInstanceOf[StringType].collationId
+  }
+
+  override def nullSafeEval(input: Any): Any = dataType match {
+    case _: StringType =>
+      getCollationKey(input.asInstanceOf[UTF8String])
+    case ArrayType(_: StringType, _) =>
+      input match {
+        case arr: Array[UTF8String] =>
+          arr.map(getCollationKey)
+        case arr: GenericArrayData =>
+          val result = new Array[UTF8String](arr.numElements())
+          for (i <- 0 until arr.numElements()) {
+            result(i) = getCollationKey(arr.getUTF8String(i))
+          }
+          new GenericArrayData(result)
+        case _ =>
+          null
+      }
+  }
+
+  def getCollationKey(str: UTF8String): UTF8String = {
+    if (collationId == CollationFactory.UTF8_BINARY_COLLATION_ID) {
+      str
+    } else if (collationId == CollationFactory.UTF8_BINARY_LCASE_COLLATION_ID) {
+      UTF8String.fromString(str.toString.toLowerCase(Locale.ROOT))
+    } else {
+      val collator = CollationFactory.fetchCollation(collationId).collator
+      val collationKey = collator.getCollationKey(str.toString)
+      UTF8String.fromBytes(collationKey.toByteArray)
+    }
+  }
+
+  override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = dataType match {
+    case _: StringType =>
+      if (collationId == CollationFactory.UTF8_BINARY_COLLATION_ID) {
+        defineCodeGen(ctx, ev, c => s"$c")
+      } else if (collationId == CollationFactory.UTF8_BINARY_LCASE_COLLATION_ID) {
+        defineCodeGen(ctx, ev, c => s"$c.toLowerCase()")
+      } else {
+        defineCodeGen(ctx, ev, c => s"UTF8String.fromBytes(CollationFactory.fetchCollation" +
+          s"($collationId).collator.getCollationKey($c.toString()).toByteArray())")
+      }
+    case ArrayType(_: StringType, _) =>
+      val expr = ctx.addReferenceObj("this", this)
+      val arrData = ctx.freshName("arrData")
+      val arrLength = ctx.freshName("arrLength")
+      val arrResult = ctx.freshName("arrResult")
+      val arrIndex = ctx.freshName("arrIndex")
+      nullSafeCodeGen(ctx, ev, eval => {
+        s"""
+           |if ($eval instanceof ArrayData) {
+           |  ArrayData $arrData = (ArrayData)$eval;
+           |  int $arrLength = $arrData.numElements();
+           |  UTF8String[] $arrResult = new UTF8String[$arrLength];
+           |  for (int $arrIndex = 0; $arrIndex < $arrLength; $arrIndex++) {
+           |    $arrResult[$arrIndex] = $expr.getCollationKey($arrData.getUTF8String($arrIndex));
+           |  }
+           |  ${ev.value} = new GenericArrayData($arrResult);
+           |} else {
+           |  ${ev.value} = null;
+           |}
+         """.stripMargin
+      })
+  }
+
+  override protected def withNewChildInternal(newChild: Expression): Expression = {
+    copy(expr = newChild)
+  }
+
+  override def child: Expression = expr
+}
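In effect, the rule keeps the query result intact but swaps the grouping keys: GROUP BY name
becomes GROUP BY collationkey(name), and the select list recovers a representative original
value through First. A minimal sketch of the rewrite, mirroring the new unit tests added to
CollationSuite below:

    import org.apache.spark.sql.catalyst.analysis.{CollationKey, RewriteGroupByCollation}
    import org.apache.spark.sql.catalyst.expressions.AttributeReference
    import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LocalRelation}
    import org.apache.spark.sql.catalyst.util.CollationFactory
    import org.apache.spark.sql.types.StringType

    val dt = StringType(CollationFactory.collationNameToId("UNICODE_CI"))
    val attr = AttributeReference("name", dt)()
    // GROUP BY name over a non-binary collation...
    val plan = Aggregate(Seq(attr), Seq(attr), LocalRelation(attr))
    // ...is rewritten to group by the collation key of name instead.
    val rewritten = RewriteGroupByCollation(plan).asInstanceOf[Aggregate]
    assert(rewritten.groupingExpressions.head == CollationKey(attr))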
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExprUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExprUtils.scala
index 258bc0ed8fe73..9642e8440bc48 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExprUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExprUtils.scala
@@ -168,7 +168,13 @@ object ExprUtils extends QueryErrorsBase {
           a.failAnalysis(
             errorClass = "MISSING_GROUP_BY",
             messageParameters = Map.empty)
-        case e: Attribute if !a.groupingExpressions.exists(_.semanticEquals(e)) =>
+        case e: Attribute if !a.groupingExpressions.exists(_.semanticEquals(e)) &&
+          !a.groupingExpressions.exists {
+            case al: Alias => al.child.semanticEquals(CollationKey(e))
+            case ck: CollationKey => ck.semanticEquals(CollationKey(e))
+            case other => other.semanticEquals(e)
+          }
+          =>
           throw QueryCompilationErrors.columnNotInGroupByClauseError(e)
         case s: ScalarSubquery
             if s.children.nonEmpty && !a.groupingExpressions.exists(_.semanticEquals(s)) =>
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index 5aa766a60c106..39d6f955916db 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -41,7 +41,7 @@ import org.apache.spark.sql.catalyst.encoders.HashableWeakReference
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.codegen.Block._
 import org.apache.spark.sql.catalyst.types._
-import org.apache.spark.sql.catalyst.util.{ArrayData, CollationFactory, CollationSupport, MapData, SQLOrderingUtil, UnsafeRowUtils}
+import org.apache.spark.sql.catalyst.util.{ArrayData, CollationFactory, CollationSupport, GenericArrayData, MapData, SQLOrderingUtil, UnsafeRowUtils}
 import org.apache.spark.sql.catalyst.util.DateTimeConstants.NANOS_PER_MILLIS
 import org.apache.spark.sql.errors.QueryExecutionErrors
 import org.apache.spark.sql.internal.SQLConf
@@ -1523,6 +1523,7 @@ object CodeGenerator extends Logging {
     classOf[CalendarInterval].getName,
     classOf[VariantVal].getName,
     classOf[ArrayData].getName,
+    classOf[GenericArrayData].getName,
     classOf[UnsafeArrayData].getName,
     classOf[MapData].getName,
     classOf[UnsafeMapData].getName,
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/MergeScalarSubqueries.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/MergeScalarSubqueries.scala
index 2d1e71a63a8ce..24f10854f7dfe 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/MergeScalarSubqueries.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/MergeScalarSubqueries.scala
@@ -353,15 +353,10 @@ object 
MergeScalarSubqueries extends Rule[LogicalPlan] { case a: AggregateExpression => a }) } - val groupByExpressionSeq = Seq(newPlan, cachedPlan).map(_.groupingExpressions) val Seq(newPlanSupportsHashAggregate, cachedPlanSupportsHashAggregate) = - aggregateExpressionsSeq.zip(groupByExpressionSeq).map { - case (aggregateExpressions, groupByExpressions) => - Aggregate.supportsHashAggregate( - aggregateExpressions.flatMap( - _.aggregateFunction.aggBufferAttributes), groupByExpressions) - } + aggregateExpressionsSeq.map(aggregateExpressions => Aggregate.supportsHashAggregate( + aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes))) newPlanSupportsHashAggregate && cachedPlanSupportsHashAggregate || newPlanSupportsHashAggregate == cachedPlanSupportsHashAggregate && { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index cacde9f5a7122..177195a43b976 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -296,6 +296,7 @@ abstract class Optimizer(catalogManager: CatalogManager) EliminateView, ReplaceExpressions, RewriteNonCorrelatedExists, + RewriteGroupByCollation, PullOutGroupingExpressions, ComputeCurrentTime, ReplaceCurrentLike(catalogManager), @@ -2231,7 +2232,8 @@ object ConvertToLocalRelation extends Rule[LogicalPlan] { object ReplaceDistinctWithAggregate extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan.transformWithPruning( _.containsPattern(DISTINCT_LIKE), ruleId) { - case Distinct(child) => Aggregate(child.output, child.output, child) + case Distinct(child) => + Aggregate(child.output, child.output, child) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PullOutGroupingExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PullOutGroupingExpressions.scala index ecc3619d584f5..13738963dd171 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PullOutGroupingExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/PullOutGroupingExpressions.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.optimizer import scala.collection.mutable +import org.apache.spark.sql.catalyst.analysis.CollationKey import org.apache.spark.sql.catalyst.expressions.{Alias, Expression, NamedExpression} import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan, Project} @@ -51,7 +52,7 @@ object PullOutGroupingExpressions extends Rule[LogicalPlan] { case a: Aggregate if a.resolved => val complexGroupingExpressionMap = mutable.LinkedHashMap.empty[Expression, NamedExpression] val newGroupingExpressions = a.groupingExpressions.toIndexedSeq.map { - case e if !e.foldable && e.children.nonEmpty => + case e if !e.foldable && e.children.nonEmpty && !e.isInstanceOf[CollationKey] => complexGroupingExpressionMap .getOrElseUpdate(e.canonicalized, Alias(e, "_groupingexpression")()) .toAttribute diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index 1c8f7a97dd7fe..7c2dfd31f4e33 100644 --- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -1232,11 +1232,9 @@ object Aggregate { schema.forall(f => UnsafeRow.isMutable(f.dataType)) } - def supportsHashAggregate( - aggregateBufferAttributes: Seq[Attribute], groupingExpression: Seq[Expression]): Boolean = { + def supportsHashAggregate(aggregateBufferAttributes: Seq[Attribute]): Boolean = { val aggregationBufferSchema = DataTypeUtils.fromAttributes(aggregateBufferAttributes) - isAggregateBufferMutable(aggregationBufferSchema) && - groupingExpression.forall(e => UnsafeRowUtils.isBinaryStable(e.dataType)) + isAggregateBufferMutable(aggregationBufferSchema) } def supportsObjectHashAggregate(aggregateExpressions: Seq[AggregateExpression]): Boolean = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala index 367d4cfafb485..278c1fc3f73b1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala @@ -76,7 +76,7 @@ object AggUtils { resultExpressions: Seq[NamedExpression] = Nil, child: SparkPlan): SparkPlan = { val useHash = Aggregate.supportsHashAggregate( - aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes), groupingExpressions) + aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes)) val forceObjHashAggregate = forceApplyObjectHashAggregate(child.conf) val forceSortAggregate = forceApplySortAggregate(child.conf) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala index bdf17607d77c5..b29b644354ddf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.execution.aggregate +import java.util.Locale import java.util.concurrent.TimeUnit._ import scala.collection.mutable @@ -33,17 +34,19 @@ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.plans.logical.Aggregate import org.apache.spark.sql.catalyst.types.DataTypeUtils import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes +import org.apache.spark.sql.catalyst.util.{truncatedString, CollationFactory} import org.apache.spark.sql.catalyst.util.DateTimeConstants.NANOS_PER_MILLIS -import org.apache.spark.sql.catalyst.util.truncatedString import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} import org.apache.spark.sql.execution.vectorized.MutableColumnarRow import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.types.{CalendarIntervalType, DecimalType, StringType} +import org.apache.spark.sql.types.{CalendarIntervalType, DecimalType, StringType, StructType} import org.apache.spark.unsafe.KVIterator +import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.ArrayImplicits._ import org.apache.spark.util.Utils + /** * Hash-based aggregate operator that can also fallback to sorting when data exceeds memory size. 
*/
@@ -59,7 +62,7 @@ case class HashAggregateExec(
     child: SparkPlan)
   extends AggregateCodegenSupport {
 
-  require(Aggregate.supportsHashAggregate(aggregateBufferAttributes, groupingExpressions))
+  require(Aggregate.supportsHashAggregate(aggregateBufferAttributes))
 
   override lazy val allAttributes: AttributeSeq =
     child.output ++ aggregateBufferAttributes ++ aggregateAttributes ++
@@ -87,6 +90,34 @@ case class HashAggregateExec(
     }
   }
 
+  // Replaces every string column of `row` with its collation key, so that values that are
+  // equal under the column's collation also compare equal byte for byte.
+  private def collationAwareStringRows(row: InternalRow, schema: StructType): InternalRow = {
+    val newRow = new Array[Any](schema.length)
+    for (i <- schema.fields.indices) {
+      val field = schema.fields(i)
+      field.dataType match {
+        case st: StringType =>
+          val str: UTF8String = row.getUTF8String(i)
+          val collationId: Int = st.collationId
+          if (CollationFactory.fetchCollation(collationId).supportsBinaryEquality) {
+            newRow(i) = str
+          } else if (collationId == CollationFactory.UTF8_BINARY_LCASE_COLLATION_ID) {
+            newRow(i) = UTF8String.fromString(str.toString.toLowerCase(Locale.ROOT))
+          } else {
+            val collator = CollationFactory.fetchCollation(collationId).collator
+            val collationKey = collator.getCollationKey(str.toString).toByteArray
+            // Render the ICU sort key as hex so the key remains a valid UTF8String.
+            newRow(i) = UTF8String.fromString(collationKey.map("%02x" format _).mkString)
+          }
+        case _ =>
+          newRow(i) = row.get(i, field.dataType)
+      }
+    }
+    val project = UnsafeProjection.create(schema)
+    project(InternalRow.fromSeq(newRow.toIndexedSeq))
+  }
+
   protected override def doExecute(): RDD[InternalRow] = {
     val numOutputRows = longMetric("numOutputRows")
     val peakMemory = longMetric("peakMemory")
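For non-binary collations other than UTF8_BINARY_LCASE, both CollationKey and
collationAwareStringRows above derive the grouping key from an ICU sort key
(CollationFactory wraps ICU4J collators). A standalone sketch of that encoding, assuming
ICU4J (com.ibm.icu) is on the classpath; the root-locale, secondary-strength setup here is
only an approximation of how CollationFactory configures UNICODE_CI:

    import java.util.Locale
    import com.ibm.icu.text.Collator

    // Secondary strength ignores case differences, so "aA" and "AA" yield the same key.
    val collator = Collator.getInstance(Locale.ROOT)
    collator.setStrength(Collator.SECONDARY)

    val key = collator.getCollationKey("aA").toByteArray
    assert(java.util.Arrays.equals(key, collator.getCollationKey("AA").toByteArray))
    // The same printable hex rendering used by collationAwareStringRows:
    println(key.map("%02x".format(_)).mkString)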
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala
index c4ddd25c99b6c..61fe7333be8d9 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala
@@ -21,18 +21,22 @@ import scala.collection.immutable.Seq
 import scala.jdk.CollectionConverters.MapHasAsJava
 
 import org.apache.spark.SparkException
-import org.apache.spark.sql.catalyst.ExtendedAnalysisException
-import org.apache.spark.sql.catalyst.util.CollationFactory
+import org.apache.spark.sql.catalyst.{ExtendedAnalysisException, InternalRow}
+import org.apache.spark.sql.catalyst.analysis.{CollationKey, RewriteGroupByCollation}
+import org.apache.spark.sql.catalyst.expressions.{BindReferences, UnsafeProjection}
+import org.apache.spark.sql.catalyst.expressions.AttributeReference
+import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LocalRelation}
+import org.apache.spark.sql.catalyst.util.{ArrayData, CollationFactory}
 import org.apache.spark.sql.connector.{DatasourceV2SQLBase, FakeV2ProviderWithCustomSchema}
 import org.apache.spark.sql.connector.catalog.{Identifier, InMemoryTable}
 import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.CatalogHelper
 import org.apache.spark.sql.connector.catalog.CatalogV2Util.withDefaultOwnership
 import org.apache.spark.sql.errors.DataTypeErrors.toSQLType
 import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
-import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec}
 import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, SortMergeJoinExec}
 import org.apache.spark.sql.internal.SqlApiConf
-import org.apache.spark.sql.types.{MapType, StringType, StructField, StructType}
+import org.apache.spark.sql.types.{ArrayType, MapType, StringType, StructField, StructType}
+import org.apache.spark.unsafe.types.UTF8String
 
 class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper {
   protected val v2Source = classOf[FakeV2ProviderWithCustomSchema].getName
@@ -287,6 +291,7 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper {
         ("unicode_CI", Seq("aaa", "bbb"), Seq(Row(1, "aaa"), Row(1, "bbb")))
       ).foreach {
         case (collationName: String, input: Seq[String], expected: Seq[Row]) =>
+          spark.conf.set("spark.sql.codegen.wholeStage", "false")
           checkAnswer(sql(
             s"""
           with t as (
@@ -299,29 +304,6 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper {
     }
   }
-
-  test("hash agg is not used for non binary collations") {
-    val tableNameNonBinary = "T_NON_BINARY"
-    val tableNameBinary = "T_BINARY"
-    withTable(tableNameNonBinary) {
-      withTable(tableNameBinary) {
-        sql(s"CREATE TABLE $tableNameNonBinary (c STRING COLLATE UTF8_BINARY_LCASE) USING PARQUET")
-        sql(s"INSERT INTO $tableNameNonBinary VALUES ('aaa')")
-        sql(s"CREATE TABLE $tableNameBinary (c STRING COLLATE UTF8_BINARY) USING PARQUET")
-        sql(s"INSERT INTO $tableNameBinary VALUES ('aaa')")
-
-        val dfNonBinary = sql(s"SELECT COUNT(*), c FROM $tableNameNonBinary GROUP BY c")
-        assert(collectFirst(dfNonBinary.queryExecution.executedPlan) {
-          case _: HashAggregateExec | _: ObjectHashAggregateExec => ()
-        }.isEmpty)
-
-        val 
dfBinary = sql(s"SELECT COUNT(*), c FROM $tableNameBinary GROUP BY c") - assert(collectFirst(dfBinary.queryExecution.executedPlan) { - case _: HashAggregateExec | _: ObjectHashAggregateExec => () - }.nonEmpty) - } - } - } - test("text writing to parquet with collation enclosed with backticks") { withTempPath{ path => sql(s"select 'a' COLLATE `UNICODE`").write.parquet(path.getAbsolutePath) @@ -981,4 +963,219 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { checkAnswer(dfNonBinary, dfBinary) } } + + test("CollationKey generates correct collation key for string") { + val testCases = Seq( + ("", "UTF8_BINARY", UTF8String.fromString("")), + ("aa", "UTF8_BINARY", UTF8String.fromString("aa")), + ("AA", "UTF8_BINARY", UTF8String.fromString("AA")), + ("aA", "UTF8_BINARY", UTF8String.fromString("aA")), + ("", "UTF8_BINARY_LCASE", UTF8String.fromString("")), + ("aa", "UTF8_BINARY_LCASE", UTF8String.fromString("aa")), + ("AA", "UTF8_BINARY_LCASE", UTF8String.fromString("aa")), + ("aA", "UTF8_BINARY_LCASE", UTF8String.fromString("aa")), + ("", "UNICODE", UTF8String.fromBytes(Array[Byte](1, 1, 0))), + ("aa", "UNICODE", UTF8String.fromBytes(Array[Byte](42, 42, 1, 6, 1, 6, 0))), + ("AA", "UNICODE", UTF8String.fromBytes(Array[Byte](42, 42, 1, 6, 1, -36, -36, 0))), + ("aA", "UNICODE", UTF8String.fromBytes(Array[Byte](42, 42, 1, 6, 1, -59, -36, 0))), + ("", "UNICODE_CI", UTF8String.fromBytes(Array[Byte](1, 0))), + ("aa", "UNICODE_CI", UTF8String.fromBytes(Array[Byte](42, 42, 1, 6, 0))), + ("AA", "UNICODE_CI", UTF8String.fromBytes(Array[Byte](42, 42, 1, 6, 0))), + ("aA", "UNICODE_CI", UTF8String.fromBytes(Array[Byte](42, 42, 1, 6, 0))) + ) + for ((input, collation, expected) <- testCases) { + val collationId: Int = CollationFactory.collationNameToId(collation) + val attrRef: AttributeReference = AttributeReference("attr", StringType(collationId))() + // generate CollationKey for the input string + val collationKey: CollationKey = CollationKey(attrRef) + assert(collationKey.resolved) + val str: UTF8String = UTF8String.fromString(input) + assert(collationKey.nullSafeEval(str) === expected) + } + } + + test("CollationKey generates correct collation key for array") { + val testCases = Seq( + ("", "UTF8_BINARY", Array(UTF8String.fromString(""))), + ("aa", "UTF8_BINARY", Array(UTF8String.fromString("aa"))), + ("AA", "UTF8_BINARY", Array(UTF8String.fromString("AA"))), + ("aA", "UTF8_BINARY", Array(UTF8String.fromString("aA"))), + ("", "UTF8_BINARY_LCASE", Array(UTF8String.fromString(""))), + ("aa", "UTF8_BINARY_LCASE", Array(UTF8String.fromString("aa"))), + ("AA", "UTF8_BINARY_LCASE", Array(UTF8String.fromString("aa"))), + ("aA", "UTF8_BINARY_LCASE", Array(UTF8String.fromString("aa"))), + ("", "UNICODE", Array(UTF8String.fromBytes(Array[Byte](1, 1, 0)))), + ("aa", "UNICODE", Array(UTF8String.fromBytes(Array[Byte](42, 42, 1, 6, 1, 6, 0)))), + ("AA", "UNICODE", Array(UTF8String.fromBytes(Array[Byte](42, 42, 1, 6, 1, -36, -36, 0)))), + ("aA", "UNICODE", Array(UTF8String.fromBytes(Array[Byte](42, 42, 1, 6, 1, -59, -36, 0)))), + ("", "UNICODE_CI", Array(UTF8String.fromBytes(Array[Byte](1, 0)))), + ("aa", "UNICODE_CI", Array(UTF8String.fromBytes(Array[Byte](42, 42, 1, 6, 0)))), + ("AA", "UNICODE_CI", Array(UTF8String.fromBytes(Array[Byte](42, 42, 1, 6, 0)))), + ("aA", "UNICODE_CI", Array(UTF8String.fromBytes(Array[Byte](42, 42, 1, 6, 0)))) + ) + for ((input, collation, expected) <- testCases) { + val collationId: Int = CollationFactory.collationNameToId(collation) + val array = 
ArrayType(StringType(collationId)) + val attrRef: AttributeReference = AttributeReference("attr", array)() + // generate CollationKey for the input string + val collationKey: CollationKey = CollationKey(attrRef) + assert(collationKey.resolved) + val arr: Array[UTF8String] = Array(UTF8String.fromString(input)) + assert(collationKey.nullSafeEval(arr) === expected) + } + } + + test("CollationKey generates correct collation key for string using codegen") { + val testCases = Seq( + ("", "UTF8_BINARY", ""), + ("aa", "UTF8_BINARY", "6161"), + ("AA", "UTF8_BINARY", "4141"), + ("aA", "UTF8_BINARY", "4161"), + ("", "UTF8_BINARY_LCASE", ""), + ("aa", "UTF8_BINARY_LCASE", "6161"), + ("AA", "UTF8_BINARY_LCASE", "6161"), + ("aA", "UTF8_BINARY_LCASE", "6161"), + ("", "UNICODE", "101"), + ("aa", "UNICODE", "60106012a2a"), + ("AA", "UNICODE", "dcdc0106012a2a"), + ("aA", "UNICODE", "dcc50106012a2a"), + ("", "UNICODE_CI", "1"), + ("aa", "UNICODE_CI", "6012a2a"), + ("AA", "UNICODE_CI", "6012a2a"), + ("aA", "UNICODE_CI", "6012a2a") + ) + for ((input, collation, expected) <- testCases) { + val collationId: Int = CollationFactory.collationNameToId(collation) + val attrRef: AttributeReference = AttributeReference("attr", StringType(collationId))() + // generate CollationKey for the input string + val collationKey: CollationKey = CollationKey(attrRef) + val str: UTF8String = UTF8String.fromString(input) + val boundExpr = BindReferences.bindReference(collationKey, Seq(attrRef)) + val ev = UnsafeProjection.create(Array(boundExpr).toIndexedSeq) + val strProj = ev.apply(InternalRow(str)) + assert(strProj.toString.split(',').last.startsWith(expected)) + } + } + + test("CollationKey generates correct collation key for array using codegen") { + val testCases = Seq( + (Seq(""), "UTF8_BINARY", ""), + (Seq("aa"), "UTF8_BINARY", "6161"), + (Seq("AA"), "UTF8_BINARY", "4141"), + (Seq("aA"), "UTF8_BINARY", "4161"), + (Seq(""), "UTF8_BINARY_LCASE", ""), + (Seq("aa"), "UTF8_BINARY_LCASE", "6161"), + (Seq("AA"), "UTF8_BINARY_LCASE", "6161"), + (Seq("aA"), "UTF8_BINARY_LCASE", "6161"), + (Seq(""), "UNICODE", "101"), + (Seq("aa"), "UNICODE", "60106012a2a"), + (Seq("AA"), "UNICODE", "dcdc0106012a2a"), + (Seq("aA"), "UNICODE", "dcc50106012a2a"), + (Seq(""), "UNICODE_CI", "1"), + (Seq("aa"), "UNICODE_CI", "6012a2a"), + (Seq("AA"), "UNICODE_CI", "6012a2a"), + (Seq("aA"), "UNICODE_CI", "6012a2a") + ) + for ((input, collation, expected) <- testCases) { + val collationId: Int = CollationFactory.collationNameToId(collation) + val array = ArrayType(StringType(collationId)) + val attrRef: AttributeReference = AttributeReference("attr", array)() + // generate CollationKey for the input string + val collationKey: CollationKey = CollationKey(attrRef) + val arr: Seq[UTF8String] = input.map(UTF8String.fromString) + val arrData: ArrayData = ArrayData.toArrayData(arr) + val boundExpr = BindReferences.bindReference(collationKey, Seq(attrRef)) + val ev = UnsafeProjection.create(Array(boundExpr).toIndexedSeq) + val strProj = ev.apply(InternalRow(arrData)) + assert(strProj.toString.split(',').last.startsWith(expected)) + } + } + + test("RewriteGroupByCollation rule rewrites Aggregate logical plan") { + val dataType = StringType(CollationFactory.collationNameToId("UNICODE_CI")) + val attrRef = AttributeReference("attr", dataType)() + // original logical plan should only contain one attribute in groupingExpressions + val originalPlan = Aggregate(Seq(attrRef), Seq(attrRef), LocalRelation(attrRef)) + assert(originalPlan.groupingExpressions.size == 1) + 
assert(originalPlan.groupingExpressions.head.isInstanceOf[AttributeReference])
+    assert(originalPlan.groupingExpressions.head == attrRef)
+    // plan level rewrite should replace attr with CollationKey in Aggregate logical plan
+    val newPlan = RewriteGroupByCollation(originalPlan).asInstanceOf[Aggregate]
+    assert(newPlan.resolved)
+    assert(newPlan.groupingExpressions.size == 1)
+    assert(newPlan.groupingExpressions.head.isInstanceOf[CollationKey])
+    assert(newPlan.groupingExpressions.head.asInstanceOf[CollationKey].expr == attrRef)
+  }
+
+  test("RewriteGroupByCollation rule works in SQL query analysis") {
+    val dataType = StringType(CollationFactory.collationNameToId("UNICODE_CI"))
+    val schema = StructType(Seq(StructField("name", dataType)))
+    val data = Seq(Row("AA"), Row("aa"), Row("BB"))
+    val df = spark.createDataFrame(spark.sparkContext.parallelize(data), schema)
+    df.createOrReplaceTempView("tempTable")
+    val dfGroupBy = spark.sql("SELECT name, COUNT(*) FROM tempTable GROUP BY name")
+    // test RewriteGroupByCollation idempotence
+    val logicalPlan = dfGroupBy.queryExecution.analyzed
+    val newPlan = RewriteGroupByCollation(logicalPlan)
+    val newNewPlan = RewriteGroupByCollation(newPlan)
+    assert(newPlan == newNewPlan)
+    // get the query execution result
+    checkAnswer(dfGroupBy, Seq(Row("AA", 2), Row("BB", 1)))
+  }
+
+  test("RewriteGroupByCollation rule works in SQL query analysis with array type") {
+    val collationId = CollationFactory.collationNameToId("UNICODE_CI")
+    val dataType = ArrayType(StringType(collationId))
+    val schema = StructType(Seq(StructField("name", dataType)))
+    val data = Seq(Row(Seq("AA")), Row(Seq("aa")), Row(Seq("BB")))
+    val df = spark.createDataFrame(spark.sparkContext.parallelize(data), schema)
+    df.createOrReplaceTempView("tempTable")
+    val dfGroupBy = spark.sql("SELECT name, COUNT(*) FROM tempTable GROUP BY name")
+    // test RewriteGroupByCollation idempotence
+    val logicalPlan = dfGroupBy.queryExecution.analyzed
+    val newPlan = RewriteGroupByCollation(logicalPlan)
+    val newNewPlan = RewriteGroupByCollation(newPlan)
+    assert(newPlan == newNewPlan)
+    // get the query execution result
+    checkAnswer(dfGroupBy, Seq(Row(Seq("AA"), 2), Row(Seq("BB"), 1)))
+  }
+
+  test("DISTINCT works on string type") {
+    val table = "table_agg"
+    withTable(table) {
+      sql(s"create table $table (a string collate utf8_binary_lcase) using parquet")
+      sql(s"insert into $table values ('aaa'), ('AAA')")
+      checkAnswer(sql(s"select distinct a from $table"), Seq(Row("aaa")))
+    }
+  }
+
+  test("DISTINCT works on array type") {
+    val table = "table_agg"
+    withTable(table) {
+      sql(s"create table $table (a array<string collate utf8_binary_lcase>) using parquet")
+      sql(s"insert into $table values (array('aaa')), (array('AAA'))")
+      checkAnswer(sql(s"select distinct a from $table"), Seq(Row(Seq("aaa"))))
+    }
+  }
+
+  test("GROUP BY works on string type") {
+    val table = "table_agg"
+    withTable(table) {
+      sql(s"create table $table (a string collate utf8_binary_lcase) using parquet")
+      sql(s"insert into $table values ('aaa'), ('AAA')")
+      checkAnswer(sql(s"select a, count(*) from $table group by a"), Seq(Row("aaa", 2)))
+    }
+  }
+
+  test("GROUP BY works on array type") {
+    val table = "table_agg"
+    withTable(table) {
+      sql(s"create table $table (a array<string collate utf8_binary_lcase>) using parquet")
+      sql(s"insert into $table values (array('aaa')), (array('AAA'))")
+      checkAnswer(sql(s"select a, count(*) from $table group by a"), Seq(Row(Seq("aaa"), 2)))
+    }
+  }
+}