diff --git a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala index 414155537290a..81039a697656e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala @@ -18,8 +18,11 @@ package org.apache.spark.sql.catalyst.util import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, AggregateFunction, Complete} import org.apache.spark.sql.connector.expressions.{Cast => V2Cast, Expression => V2Expression, Extract => V2Extract, FieldReference, GeneralScalarExpression, LiteralValue, UserDefinedScalarFunc} +import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, Avg, Count, CountStar, GeneralAggregateFunc, Max, Min, Sum, UserDefinedAggregateFunc} import org.apache.spark.sql.connector.expressions.filter.{AlwaysFalse, AlwaysTrue, And => V2And, Not => V2Not, Or => V2Or, Predicate => V2Predicate} +import org.apache.spark.sql.execution.datasources.PushableExpression import org.apache.spark.sql.types.{BooleanType, IntegerType} /** @@ -90,6 +93,8 @@ class V2ExpressionBuilder(e: Expression, isPredicate: Boolean = false) { } case Cast(child, dataType, _, true) => generateExpression(child).map(v => new V2Cast(v, dataType)) + case AggregateExpression(aggregateFunction, Complete, isDistinct, None, _) => + generateAggregateFunc(aggregateFunction, isDistinct) case Abs(child, true) => generateExpressionWithName("ABS", Seq(child)) case Coalesce(children) => generateExpressionWithName("COALESCE", children) case Greatest(children) => generateExpressionWithName("GREATEST", children) @@ -269,6 +274,54 @@ class V2ExpressionBuilder(e: Expression, isPredicate: Boolean = false) { case _ => None } + private def generateAggregateFunc( + aggregateFunction: AggregateFunction, + isDistinct: Boolean): Option[AggregateFunc] = aggregateFunction match { + case aggregate.Min(PushableExpression(expr)) => Some(new Min(expr)) + case aggregate.Max(PushableExpression(expr)) => Some(new Max(expr)) + case count: aggregate.Count if count.children.length == 1 => + count.children.head match { + // COUNT(any literal) is the same as COUNT(*) + case Literal(_, _) => Some(new CountStar()) + case PushableExpression(expr) => Some(new Count(expr, isDistinct)) + case _ => None + } + case aggregate.Sum(PushableExpression(expr), _) => Some(new Sum(expr, isDistinct)) + case aggregate.Average(PushableExpression(expr), _) => Some(new Avg(expr, isDistinct)) + case aggregate.VariancePop(PushableExpression(expr), _) => + Some(new GeneralAggregateFunc("VAR_POP", isDistinct, Array(expr))) + case aggregate.VarianceSamp(PushableExpression(expr), _) => + Some(new GeneralAggregateFunc("VAR_SAMP", isDistinct, Array(expr))) + case aggregate.StddevPop(PushableExpression(expr), _) => + Some(new GeneralAggregateFunc("STDDEV_POP", isDistinct, Array(expr))) + case aggregate.StddevSamp(PushableExpression(expr), _) => + Some(new GeneralAggregateFunc("STDDEV_SAMP", isDistinct, Array(expr))) + case aggregate.CovPopulation(PushableExpression(left), PushableExpression(right), _) => + Some(new GeneralAggregateFunc("COVAR_POP", isDistinct, Array(left, right))) + case aggregate.CovSample(PushableExpression(left), PushableExpression(right), _) => + Some(new GeneralAggregateFunc("COVAR_SAMP", isDistinct, Array(left, right))) + case aggregate.Corr(PushableExpression(left), PushableExpression(right), _) => + Some(new GeneralAggregateFunc("CORR", isDistinct, Array(left, right))) + case aggregate.RegrIntercept(PushableExpression(left), PushableExpression(right)) => + Some(new GeneralAggregateFunc("REGR_INTERCEPT", isDistinct, Array(left, right))) + case aggregate.RegrR2(PushableExpression(left), PushableExpression(right)) => + Some(new GeneralAggregateFunc("REGR_R2", isDistinct, Array(left, right))) + case aggregate.RegrSlope(PushableExpression(left), PushableExpression(right)) => + Some(new GeneralAggregateFunc("REGR_SLOPE", isDistinct, Array(left, right))) + case aggregate.RegrSXY(PushableExpression(left), PushableExpression(right)) => + Some(new GeneralAggregateFunc("REGR_SXY", isDistinct, Array(left, right))) + // TODO supports other aggregate functions + case aggregate.V2Aggregator(aggrFunc, children, _, _) => + val translatedExprs = children.flatMap(PushableExpression.unapply(_)) + if (translatedExprs.length == children.length) { + Some(new UserDefinedAggregateFunc(aggrFunc.name(), + aggrFunc.canonicalName(), isDistinct, translatedExprs.toArray[V2Expression])) + } else { + None + } + case _ => None + } + private def flipComparisonOperatorName(operatorName: String): String = { operatorName match { case ">" => "<" diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala index 27e4e1461773a..a9d5c6da3844c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala @@ -42,7 +42,7 @@ import org.apache.spark.sql.catalyst.util.{ResolveDefaultColumns, V2ExpressionBu import org.apache.spark.sql.connector.catalog.SupportsRead import org.apache.spark.sql.connector.catalog.TableCapability._ import org.apache.spark.sql.connector.expressions.{Expression => V2Expression, NullOrdering, SortDirection, SortOrder => V2SortOrder, SortValue} -import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, Aggregation, Avg, Count, CountStar, GeneralAggregateFunc, Max, Min, Sum, UserDefinedAggregateFunc} +import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, Aggregation} import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.execution.{RowDataSourceScanExec, SparkPlan} import org.apache.spark.sql.execution.command._ @@ -694,57 +694,6 @@ object DataSourceStrategy (nonconvertiblePredicates ++ unhandledPredicates, pushedFilters, handledFilters) } - protected[sql] def translateAggregate(agg: AggregateExpression): Option[AggregateFunc] = { - if (agg.filter.isEmpty) { - agg.aggregateFunction match { - case aggregate.Min(PushableExpression(expr)) => Some(new Min(expr)) - case aggregate.Max(PushableExpression(expr)) => Some(new Max(expr)) - case count: aggregate.Count if count.children.length == 1 => - count.children.head match { - // COUNT(any literal) is the same as COUNT(*) - case Literal(_, _) => Some(new CountStar()) - case PushableExpression(expr) => Some(new Count(expr, agg.isDistinct)) - case _ => None - } - case aggregate.Sum(PushableExpression(expr), _) => Some(new Sum(expr, agg.isDistinct)) - case aggregate.Average(PushableExpression(expr), _) => Some(new Avg(expr, agg.isDistinct)) - case aggregate.VariancePop(PushableExpression(expr), _) => - Some(new GeneralAggregateFunc("VAR_POP", agg.isDistinct, Array(expr))) - case aggregate.VarianceSamp(PushableExpression(expr), _) => - Some(new GeneralAggregateFunc("VAR_SAMP", agg.isDistinct, Array(expr))) - case aggregate.StddevPop(PushableExpression(expr), _) => - Some(new GeneralAggregateFunc("STDDEV_POP", agg.isDistinct, Array(expr))) - case aggregate.StddevSamp(PushableExpression(expr), _) => - Some(new GeneralAggregateFunc("STDDEV_SAMP", agg.isDistinct, Array(expr))) - case aggregate.CovPopulation(PushableExpression(left), PushableExpression(right), _) => - Some(new GeneralAggregateFunc("COVAR_POP", agg.isDistinct, Array(left, right))) - case aggregate.CovSample(PushableExpression(left), PushableExpression(right), _) => - Some(new GeneralAggregateFunc("COVAR_SAMP", agg.isDistinct, Array(left, right))) - case aggregate.Corr(PushableExpression(left), PushableExpression(right), _) => - Some(new GeneralAggregateFunc("CORR", agg.isDistinct, Array(left, right))) - case aggregate.RegrIntercept(PushableExpression(left), PushableExpression(right)) => - Some(new GeneralAggregateFunc("REGR_INTERCEPT", agg.isDistinct, Array(left, right))) - case aggregate.RegrR2(PushableExpression(left), PushableExpression(right)) => - Some(new GeneralAggregateFunc("REGR_R2", agg.isDistinct, Array(left, right))) - case aggregate.RegrSlope(PushableExpression(left), PushableExpression(right)) => - Some(new GeneralAggregateFunc("REGR_SLOPE", agg.isDistinct, Array(left, right))) - case aggregate.RegrSXY(PushableExpression(left), PushableExpression(right)) => - Some(new GeneralAggregateFunc("REGR_SXY", agg.isDistinct, Array(left, right))) - case aggregate.V2Aggregator(aggrFunc, children, _, _) => - val translatedExprs = children.flatMap(PushableExpression.unapply(_)) - if (translatedExprs.length == children.length) { - Some(new UserDefinedAggregateFunc(aggrFunc.name(), - aggrFunc.canonicalName(), agg.isDistinct, translatedExprs.toArray[V2Expression])) - } else { - None - } - case _ => None - } - } else { - None - } - } - /** * Translate aggregate expressions and group by expressions. * @@ -753,13 +702,13 @@ object DataSourceStrategy protected[sql] def translateAggregation( aggregates: Seq[AggregateExpression], groupBy: Seq[Expression]): Option[Aggregation] = { - def translateGroupBy(e: Expression): Option[V2Expression] = e match { + def translate(e: Expression): Option[V2Expression] = e match { case PushableExpression(expr) => Some(expr) case _ => None } - val translatedAggregates = aggregates.flatMap(translateAggregate) - val translatedGroupBys = groupBy.flatMap(translateGroupBy) + val translatedAggregates = aggregates.flatMap(translate).asInstanceOf[Seq[AggregateFunc]] + val translatedGroupBys = groupBy.flatMap(translate) if (translatedAggregates.length != aggregates.length || translatedGroupBys.length != groupBy.length) {