@@ -259,4 +259,15 @@ public static Collation fetchCollation(String collationName) throws SparkException
    int collationId = collationNameToId(collationName);
    return collationTable[collationId];
  }

  public static UTF8String getCollationKeyLcase(UTF8String str) {
    return str.toLowerCase();
  }

  public static UTF8String[] getCollationKeyLcase(UTF8String[] arr) {
    for (int i = 0; i < arr.length; i++) {
      arr[i] = arr[i].toLowerCase();
    }
    return arr;
  }
}
@@ -0,0 +1,148 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.catalyst.analysis

import java.util.Locale

import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference, ExpectsInputTypes, Expression, UnaryExpression}
import org.apache.spark.sql.catalyst.expressions.aggregate.First
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.util.{CollationFactory, GenericArrayData}
import org.apache.spark.sql.internal.types.{AbstractArrayType, StringTypeAnyCollation}
import org.apache.spark.sql.types.{AbstractDataType, ArrayType, DataType, StringType, TypeCollection}
import org.apache.spark.unsafe.types.UTF8String

/**
* A rule that rewrites aggregation operations using plans that operate on collated strings.
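*
* For example (a sketch, assuming a grouping column `name` with a non-binary collation), a plan like
*   Aggregate [name], [name, count(1)]
* is rewritten into roughly
*   Aggregate [collationkey(name)], [first(name) AS name, count(1)]
* so that two strings that are equal under the collation produce the same grouping key.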
*/
object RewriteGroupByCollation extends Rule[LogicalPlan] {
  def apply(plan: LogicalPlan): LogicalPlan = plan transformUpWithNewOutput {
    case a: Aggregate =>
      val aliasMap = a.groupingExpressions.collect {
        case attr: AttributeReference if attr.dataType.isInstanceOf[StringType] =>
          attr -> CollationKey(attr)
        case attr: AttributeReference if attr.dataType.isInstanceOf[ArrayType]
            && attr.dataType.asInstanceOf[ArrayType].elementType.isInstanceOf[StringType] =>
          attr -> CollationKey(attr)
      }.toMap

      val newGroupingExpressions = a.groupingExpressions.map {
        case attr: AttributeReference if aliasMap.contains(attr) =>
          aliasMap(attr)
        case other => other
      }

      val newAggregateExpressions = a.aggregateExpressions.map {
        case attr: AttributeReference if aliasMap.contains(attr) =>
          Alias(First(attr, ignoreNulls = false).toAggregateExpression(), attr.name)()
        case other => other
      }

      val newAggregate = a.copy(
        groupingExpressions = newGroupingExpressions,
        aggregateExpressions = newAggregateExpressions
      )

      (newAggregate, a.output.zip(newAggregate.output))
  }
}

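/**
 * Computes a collation-aware grouping key for a collated string (or an array of collated
 * strings): the value itself for UTF8_BINARY, its lowercased form for UTF8_BINARY_LCASE,
 * and the ICU collation key bytes for other collations.
 */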
case class CollationKey(expr: Expression) extends UnaryExpression with ExpectsInputTypes {
  override def inputTypes: Seq[AbstractDataType] =
    Seq(TypeCollection(StringTypeAnyCollation, AbstractArrayType(StringTypeAnyCollation)))
  override def dataType: DataType = expr.dataType

  final lazy val collationId: Int = dataType match {
    case _: StringType =>
      dataType.asInstanceOf[StringType].collationId
    case ArrayType(_: StringType, _) =>
      val arr = dataType.asInstanceOf[ArrayType]
      arr.elementType.asInstanceOf[StringType].collationId
  }

  override def nullSafeEval(input: Any): Any = dataType match {
    case _: StringType =>
      getCollationKey(input.asInstanceOf[UTF8String])
    case ArrayType(_: StringType, _) =>
      input match {
        case arr: Array[UTF8String] =>
          arr.map(getCollationKey)
        case arr: GenericArrayData =>
          val result = new Array[UTF8String](arr.numElements())
          for (i <- 0 until arr.numElements()) {
            result(i) = getCollationKey(arr.getUTF8String(i))
          }
          new GenericArrayData(result)
        case _ =>
          None
      }
  }

  def getCollationKey(str: UTF8String): UTF8String = {
    if (collationId == CollationFactory.UTF8_BINARY_COLLATION_ID) {
      str
    } else if (collationId == CollationFactory.UTF8_BINARY_LCASE_COLLATION_ID) {
      UTF8String.fromString(str.toString.toLowerCase(Locale.ROOT))
    } else {
      val collator = CollationFactory.fetchCollation(collationId).collator
      val collationKey = collator.getCollationKey(str.toString)
      UTF8String.fromBytes(collationKey.toByteArray)
    }
  }

  override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = dataType match {
    case _: StringType =>
      if (collationId == CollationFactory.UTF8_BINARY_COLLATION_ID) {
        defineCodeGen(ctx, ev, c => s"$c")
      } else if (collationId == CollationFactory.UTF8_BINARY_LCASE_COLLATION_ID) {
        defineCodeGen(ctx, ev, c => s"$c.toLowerCase()")
      } else {
        defineCodeGen(ctx, ev, c => s"UTF8String.fromBytes(CollationFactory.fetchCollation" +
          s"($collationId).collator.getCollationKey($c.toString()).toByteArray())")
      }
    case ArrayType(_: StringType, _) =>
      val expr = ctx.addReferenceObj("this", this)
      val arrData = ctx.freshName("arrData")
      val arrLength = ctx.freshName("arrLength")
      val arrResult = ctx.freshName("arrResult")
      val arrIndex = ctx.freshName("arrIndex")
      nullSafeCodeGen(ctx, ev, eval => {
        s"""
           |if ($eval instanceof ArrayData) {
           |  ArrayData $arrData = (ArrayData)$eval;
           |  int $arrLength = $arrData.numElements();
           |  UTF8String[] $arrResult = new UTF8String[$arrLength];
           |  for (int $arrIndex = 0; $arrIndex < $arrLength; $arrIndex++) {
           |    $arrResult[$arrIndex] = $expr.getCollationKey($arrData.getUTF8String($arrIndex));
           |  }
           |  ${ev.value} = new GenericArrayData($arrResult);
           |} else {
           |  ${ev.value} = null;
           |}
         """.stripMargin
      })
  }

  override protected def withNewChildInternal(newChild: Expression): Expression = {
    copy(expr = newChild)
  }

  override def child: Expression = expr
}
@@ -168,7 +168,13 @@ object ExprUtils extends QueryErrorsBase {
        a.failAnalysis(
          errorClass = "MISSING_GROUP_BY",
          messageParameters = Map.empty)
      case e: Attribute if !a.groupingExpressions.exists(_.semanticEquals(e)) =>
      case e: Attribute if !a.groupingExpressions.exists(_.semanticEquals(e)) &&
        !a.groupingExpressions.exists {

Contributor:
why do we need this change?

Contributor Author (@uros-db, Apr 8, 2024):

Without it, this hits QueryCompilationErrors.columnNotInGroupByClauseError(e), with the error:

Job aborted due to stage failure: Task 0 in stage 2.0 failed 1 times, most recent failure: Lost task 0.0 in stage 2.0 (TID 2) (192.168.15.30 executor driver): org.apache.spark.SparkException: [INTERNAL_ERROR] Couldn't find name#17 in [collationkey(name#17)#30,count(1)#19L] SQLSTATE: XX000

with query plan:

Aggregate [collationkey(name#17)], [any_value(name#17, false) AS name#24, count(1) AS count(1)#20L]
+- LogicalRDD [name#17], false

so name#17 is not found in groupingExpressions
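
For illustration, a minimal repro sketch of the kind of grouped aggregate that exercises this path (hypothetical; the exact test query is not shown in this thread), assuming a table t whose name column uses the UTF8_BINARY_LCASE collation:

// hypothetical repro; assumes t.name is declared with the UTF8_BINARY_LCASE collation
val df = spark.sql("SELECT name, COUNT(1) FROM t GROUP BY name")
df.explain()  // the grouping key becomes collationkey(name), while the select list still references name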

          case al: Alias => al.child.semanticEquals(CollationKey(e))
          case ck: CollationKey => ck.semanticEquals(CollationKey(e))
          case other => other.semanticEquals(e)
        }
        =>
        throw QueryCompilationErrors.columnNotInGroupByClauseError(e)
      case s: ScalarSubquery
          if s.children.nonEmpty && !a.groupingExpressions.exists(_.semanticEquals(s)) =>
@@ -41,7 +41,7 @@ import org.apache.spark.sql.catalyst.encoders.HashableWeakReference
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.types._
import org.apache.spark.sql.catalyst.util.{ArrayData, CollationFactory, CollationSupport, MapData, SQLOrderingUtil, UnsafeRowUtils}
import org.apache.spark.sql.catalyst.util.{ArrayData, CollationFactory, CollationSupport, GenericArrayData, MapData, SQLOrderingUtil, UnsafeRowUtils}
import org.apache.spark.sql.catalyst.util.DateTimeConstants.NANOS_PER_MILLIS
import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.sql.internal.SQLConf
@@ -1523,6 +1523,7 @@ object CodeGenerator extends Logging {
      classOf[CalendarInterval].getName,
      classOf[VariantVal].getName,
      classOf[ArrayData].getName,
      classOf[GenericArrayData].getName,
      classOf[UnsafeArrayData].getName,
      classOf[MapData].getName,
      classOf[UnsafeMapData].getName,
@@ -353,15 +353,10 @@ object MergeScalarSubqueries extends Rule[LogicalPlan] {
        case a: AggregateExpression => a
      })
    }
    val groupByExpressionSeq = Seq(newPlan, cachedPlan).map(_.groupingExpressions)

    val Seq(newPlanSupportsHashAggregate, cachedPlanSupportsHashAggregate) =
      aggregateExpressionsSeq.zip(groupByExpressionSeq).map {
        case (aggregateExpressions, groupByExpressions) =>
          Aggregate.supportsHashAggregate(
            aggregateExpressions.flatMap(
              _.aggregateFunction.aggBufferAttributes), groupByExpressions)
      }
      aggregateExpressionsSeq.map(aggregateExpressions => Aggregate.supportsHashAggregate(
        aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes)))

    newPlanSupportsHashAggregate && cachedPlanSupportsHashAggregate ||
      newPlanSupportsHashAggregate == cachedPlanSupportsHashAggregate && {
@@ -296,6 +296,7 @@ abstract class Optimizer(catalogManager: CatalogManager)
      EliminateView,
      ReplaceExpressions,
      RewriteNonCorrelatedExists,
      RewriteGroupByCollation,
      PullOutGroupingExpressions,
      ComputeCurrentTime,
      ReplaceCurrentLike(catalogManager),
@@ -2231,7 +2232,8 @@ object ConvertToLocalRelation extends Rule[LogicalPlan] {
object ReplaceDistinctWithAggregate extends Rule[LogicalPlan] {
  def apply(plan: LogicalPlan): LogicalPlan = plan.transformWithPruning(
    _.containsPattern(DISTINCT_LIKE), ruleId) {
    case Distinct(child) => Aggregate(child.output, child.output, child)
    case Distinct(child) =>
      Aggregate(child.output, child.output, child)
  }
}

@@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.optimizer

import scala.collection.mutable

import org.apache.spark.sql.catalyst.analysis.CollationKey
import org.apache.spark.sql.catalyst.expressions.{Alias, Expression, NamedExpression}
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan, Project}
@@ -51,7 +52,7 @@ object PullOutGroupingExpressions extends Rule[LogicalPlan] {
    case a: Aggregate if a.resolved =>
      val complexGroupingExpressionMap = mutable.LinkedHashMap.empty[Expression, NamedExpression]
      val newGroupingExpressions = a.groupingExpressions.toIndexedSeq.map {
        case e if !e.foldable && e.children.nonEmpty =>
        case e if !e.foldable && e.children.nonEmpty && !e.isInstanceOf[CollationKey] =>
          complexGroupingExpressionMap
            .getOrElseUpdate(e.canonicalized, Alias(e, "_groupingexpression")())
            .toAttribute
@@ -1232,11 +1232,9 @@ object Aggregate {
    schema.forall(f => UnsafeRow.isMutable(f.dataType))
  }

  def supportsHashAggregate(
      aggregateBufferAttributes: Seq[Attribute], groupingExpression: Seq[Expression]): Boolean = {
  def supportsHashAggregate(aggregateBufferAttributes: Seq[Attribute]): Boolean = {
    val aggregationBufferSchema = DataTypeUtils.fromAttributes(aggregateBufferAttributes)
    isAggregateBufferMutable(aggregationBufferSchema) &&
      groupingExpression.forall(e => UnsafeRowUtils.isBinaryStable(e.dataType))
    isAggregateBufferMutable(aggregationBufferSchema)
  }

  def supportsObjectHashAggregate(aggregateExpressions: Seq[AggregateExpression]): Boolean = {
@@ -76,7 +76,7 @@ object AggUtils {
      resultExpressions: Seq[NamedExpression] = Nil,
      child: SparkPlan): SparkPlan = {
    val useHash = Aggregate.supportsHashAggregate(
      aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes), groupingExpressions)
      aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes))

    val forceObjHashAggregate = forceApplyObjectHashAggregate(child.conf)
    val forceSortAggregate = forceApplySortAggregate(child.conf)
@@ -17,6 +17,7 @@

package org.apache.spark.sql.execution.aggregate

import java.util.Locale
import java.util.concurrent.TimeUnit._

import scala.collection.mutable
@@ -33,17 +34,19 @@ import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.plans.logical.Aggregate
import org.apache.spark.sql.catalyst.types.DataTypeUtils
import org.apache.spark.sql.catalyst.types.DataTypeUtils.toAttributes
import org.apache.spark.sql.catalyst.util.{truncatedString, CollationFactory}
import org.apache.spark.sql.catalyst.util.DateTimeConstants.NANOS_PER_MILLIS
import org.apache.spark.sql.catalyst.util.truncatedString
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
import org.apache.spark.sql.execution.vectorized.MutableColumnarRow
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{CalendarIntervalType, DecimalType, StringType}
import org.apache.spark.sql.types.{CalendarIntervalType, DecimalType, StringType, StructType}
import org.apache.spark.unsafe.KVIterator
import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.util.ArrayImplicits._
import org.apache.spark.util.Utils


/**
* Hash-based aggregate operator that can also fallback to sorting when data exceeds memory size.
*/
@@ -59,7 +62,7 @@ case class HashAggregateExec(
    child: SparkPlan)
  extends AggregateCodegenSupport {

  require(Aggregate.supportsHashAggregate(aggregateBufferAttributes, groupingExpressions))
  require(Aggregate.supportsHashAggregate(aggregateBufferAttributes))

  override lazy val allAttributes: AttributeSeq =
    child.output ++ aggregateBufferAttributes ++ aggregateAttributes ++
@@ -87,6 +90,31 @@ case class HashAggregateExec(
    }
  }

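  // Replaces each string column in `row` with its collation key: the raw value for collations
  // with binary equality, the lowercased value for UTF8_BINARY_LCASE, and the ICU collation key
  // bytes (hex-encoded) otherwise; non-string columns are copied unchanged.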
  private def collationAwareStringRows(row: InternalRow, schema: StructType): InternalRow = {
    val newRow = new Array[Any](schema.length)
    for (i <- schema.fields.indices) {
      val field = schema.fields(i)
      field.dataType match {
        case st: StringType =>
          val str: UTF8String = row.getUTF8String(i)
          val collationId: Int = st.collationId
          if (CollationFactory.fetchCollation(collationId).supportsBinaryEquality) {
            newRow(i) = str
          } else if (collationId == CollationFactory.UTF8_BINARY_LCASE_COLLATION_ID) {
            newRow(i) = UTF8String.fromString(str.toString.toLowerCase(Locale.ROOT))
          } else {
            val collator = CollationFactory.fetchCollation(collationId).collator
            val collationKey = collator.getCollationKey(str.toString).toByteArray
            newRow(i) = UTF8String.fromString(collationKey.map("%02x" format _).mkString)
          }
        case _ =>
          newRow(i) = row.get(i, field.dataType)
      }
    }
    val project = UnsafeProjection.create(schema)
    project(InternalRow.fromSeq(newRow.toIndexedSeq))
  }

  protected override def doExecute(): RDD[InternalRow] = {
    val numOutputRows = longMetric("numOutputRows")
    val peakMemory = longMetric("peakMemory")
@@ -96,9 +124,9 @@ case class HashAggregateExec(
    val numTasksFallBacked = longMetric("numTasksFallBacked")

    child.execute().mapPartitionsWithIndex { (partIndex, iter) =>

      // val collationAwareIterator = iter.map(row => collationAwareStringRows(row, child.schema))
      val beforeAgg = System.nanoTime()
      val hasInput = iter.hasNext
      val hasInput = iter.hasNext // collationAwareIterator.hasNext
      val res = if (!hasInput && groupingExpressions.nonEmpty) {
        // This is a grouped aggregate and the input iterator is empty,
        // so return an empty iterator.
@@ -115,7 +143,7 @@ case class HashAggregateExec(
            (expressions, inputSchema) =>
              MutableProjection.create(expressions, inputSchema),
            inputAttributes,
            iter,
            iter, // collationAwareIterator
            testFallbackStartsAt,
            numOutputRows,
            peakMemory,