From f10fcdfb82645785997783abb2c51661de4a2828 Mon Sep 17 00:00:00 2001 From: Jiaan Geng Date: Wed, 3 Aug 2022 12:19:26 +0800 Subject: [PATCH 1/4] [SPARK-39964][SQL] DS V2 pushdown should unify the translate API --- .../catalyst/util/V2ExpressionBuilder.scala | 49 +++++++++++++++ .../datasources/DataSourceStrategy.scala | 59 ++----------------- 2 files changed, 53 insertions(+), 55 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala index 414155537290a..ddb4ecdfbd83d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala @@ -18,8 +18,11 @@ package org.apache.spark.sql.catalyst.util import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.connector.expressions.{Cast => V2Cast, Expression => V2Expression, Extract => V2Extract, FieldReference, GeneralScalarExpression, LiteralValue, UserDefinedScalarFunc} +import org.apache.spark.sql.connector.expressions.aggregate.{Avg, Count, CountStar, GeneralAggregateFunc, Max, Min, Sum, UserDefinedAggregateFunc} import org.apache.spark.sql.connector.expressions.filter.{AlwaysFalse, AlwaysTrue, And => V2And, Not => V2Not, Or => V2Or, Predicate => V2Predicate} +import org.apache.spark.sql.execution.datasources.PushableExpression import org.apache.spark.sql.types.{BooleanType, IntegerType} /** @@ -90,6 +93,52 @@ class V2ExpressionBuilder(e: Expression, isPredicate: Boolean = false) { } case Cast(child, dataType, _, true) => generateExpression(child).map(v => new V2Cast(v, dataType)) + case AggregateExpression(aggregateFunction, _, isDistinct, None, _) => + aggregateFunction match { + case aggregate.Min(PushableExpression(expr)) => Some(new Min(expr)) + case aggregate.Max(PushableExpression(expr)) => Some(new Max(expr)) + case count: aggregate.Count if count.children.length == 1 => + count.children.head match { + // COUNT(any literal) is the same as COUNT(*) + case Literal(_, _) => Some(new CountStar()) + case PushableExpression(expr) => Some(new Count(expr, isDistinct)) + case _ => None + } + case aggregate.Sum(PushableExpression(expr), _) => Some(new Sum(expr, isDistinct)) + case aggregate.Average(PushableExpression(expr), _) => Some(new Avg(expr, isDistinct)) + case aggregate.VariancePop(PushableExpression(expr), _) => + Some(new GeneralAggregateFunc("VAR_POP", isDistinct, Array(expr))) + case aggregate.VarianceSamp(PushableExpression(expr), _) => + Some(new GeneralAggregateFunc("VAR_SAMP", isDistinct, Array(expr))) + case aggregate.StddevPop(PushableExpression(expr), _) => + Some(new GeneralAggregateFunc("STDDEV_POP", isDistinct, Array(expr))) + case aggregate.StddevSamp(PushableExpression(expr), _) => + Some(new GeneralAggregateFunc("STDDEV_SAMP", isDistinct, Array(expr))) + case aggregate.CovPopulation(PushableExpression(left), PushableExpression(right), _) => + Some(new GeneralAggregateFunc("COVAR_POP", isDistinct, Array(left, right))) + case aggregate.CovSample(PushableExpression(left), PushableExpression(right), _) => + Some(new GeneralAggregateFunc("COVAR_SAMP", isDistinct, Array(left, right))) + case aggregate.Corr(PushableExpression(left), PushableExpression(right), _) => + Some(new GeneralAggregateFunc("CORR", isDistinct, Array(left, right))) + case aggregate.RegrIntercept(PushableExpression(left), PushableExpression(right)) => + Some(new GeneralAggregateFunc("REGR_INTERCEPT", isDistinct, Array(left, right))) + case aggregate.RegrR2(PushableExpression(left), PushableExpression(right)) => + Some(new GeneralAggregateFunc("REGR_R2", isDistinct, Array(left, right))) + case aggregate.RegrSlope(PushableExpression(left), PushableExpression(right)) => + Some(new GeneralAggregateFunc("REGR_SLOPE", isDistinct, Array(left, right))) + case aggregate.RegrSXY(PushableExpression(left), PushableExpression(right)) => + Some(new GeneralAggregateFunc("REGR_SXY", isDistinct, Array(left, right))) + // TODO supports other aggregate functions + case aggregate.V2Aggregator(aggrFunc, children, _, _) => + val translatedExprs = children.flatMap(PushableExpression.unapply(_)) + if (translatedExprs.length == children.length) { + Some(new UserDefinedAggregateFunc(aggrFunc.name(), + aggrFunc.canonicalName(), isDistinct, translatedExprs.toArray[V2Expression])) + } else { + None + } + case _ => None + } case Abs(child, true) => generateExpressionWithName("ABS", Seq(child)) case Coalesce(children) => generateExpressionWithName("COALESCE", children) case Greatest(children) => generateExpressionWithName("GREATEST", children) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala index 27e4e1461773a..a9d5c6da3844c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala @@ -42,7 +42,7 @@ import org.apache.spark.sql.catalyst.util.{ResolveDefaultColumns, V2ExpressionBu import org.apache.spark.sql.connector.catalog.SupportsRead import org.apache.spark.sql.connector.catalog.TableCapability._ import org.apache.spark.sql.connector.expressions.{Expression => V2Expression, NullOrdering, SortDirection, SortOrder => V2SortOrder, SortValue} -import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, Aggregation, Avg, Count, CountStar, GeneralAggregateFunc, Max, Min, Sum, UserDefinedAggregateFunc} +import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, Aggregation} import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.execution.{RowDataSourceScanExec, SparkPlan} import org.apache.spark.sql.execution.command._ @@ -694,57 +694,6 @@ object DataSourceStrategy (nonconvertiblePredicates ++ unhandledPredicates, pushedFilters, handledFilters) } - protected[sql] def translateAggregate(agg: AggregateExpression): Option[AggregateFunc] = { - if (agg.filter.isEmpty) { - agg.aggregateFunction match { - case aggregate.Min(PushableExpression(expr)) => Some(new Min(expr)) - case aggregate.Max(PushableExpression(expr)) => Some(new Max(expr)) - case count: aggregate.Count if count.children.length == 1 => - count.children.head match { - // COUNT(any literal) is the same as COUNT(*) - case Literal(_, _) => Some(new CountStar()) - case PushableExpression(expr) => Some(new Count(expr, agg.isDistinct)) - case _ => None - } - case aggregate.Sum(PushableExpression(expr), _) => Some(new Sum(expr, agg.isDistinct)) - case aggregate.Average(PushableExpression(expr), _) => Some(new Avg(expr, agg.isDistinct)) - case aggregate.VariancePop(PushableExpression(expr), _) => - Some(new GeneralAggregateFunc("VAR_POP", agg.isDistinct, Array(expr))) - case aggregate.VarianceSamp(PushableExpression(expr), _) => - Some(new GeneralAggregateFunc("VAR_SAMP", agg.isDistinct, Array(expr))) - case aggregate.StddevPop(PushableExpression(expr), _) => - Some(new GeneralAggregateFunc("STDDEV_POP", agg.isDistinct, Array(expr))) - case aggregate.StddevSamp(PushableExpression(expr), _) => - Some(new GeneralAggregateFunc("STDDEV_SAMP", agg.isDistinct, Array(expr))) - case aggregate.CovPopulation(PushableExpression(left), PushableExpression(right), _) => - Some(new GeneralAggregateFunc("COVAR_POP", agg.isDistinct, Array(left, right))) - case aggregate.CovSample(PushableExpression(left), PushableExpression(right), _) => - Some(new GeneralAggregateFunc("COVAR_SAMP", agg.isDistinct, Array(left, right))) - case aggregate.Corr(PushableExpression(left), PushableExpression(right), _) => - Some(new GeneralAggregateFunc("CORR", agg.isDistinct, Array(left, right))) - case aggregate.RegrIntercept(PushableExpression(left), PushableExpression(right)) => - Some(new GeneralAggregateFunc("REGR_INTERCEPT", agg.isDistinct, Array(left, right))) - case aggregate.RegrR2(PushableExpression(left), PushableExpression(right)) => - Some(new GeneralAggregateFunc("REGR_R2", agg.isDistinct, Array(left, right))) - case aggregate.RegrSlope(PushableExpression(left), PushableExpression(right)) => - Some(new GeneralAggregateFunc("REGR_SLOPE", agg.isDistinct, Array(left, right))) - case aggregate.RegrSXY(PushableExpression(left), PushableExpression(right)) => - Some(new GeneralAggregateFunc("REGR_SXY", agg.isDistinct, Array(left, right))) - case aggregate.V2Aggregator(aggrFunc, children, _, _) => - val translatedExprs = children.flatMap(PushableExpression.unapply(_)) - if (translatedExprs.length == children.length) { - Some(new UserDefinedAggregateFunc(aggrFunc.name(), - aggrFunc.canonicalName(), agg.isDistinct, translatedExprs.toArray[V2Expression])) - } else { - None - } - case _ => None - } - } else { - None - } - } - /** * Translate aggregate expressions and group by expressions. * @@ -753,13 +702,13 @@ object DataSourceStrategy protected[sql] def translateAggregation( aggregates: Seq[AggregateExpression], groupBy: Seq[Expression]): Option[Aggregation] = { - def translateGroupBy(e: Expression): Option[V2Expression] = e match { + def translate(e: Expression): Option[V2Expression] = e match { case PushableExpression(expr) => Some(expr) case _ => None } - val translatedAggregates = aggregates.flatMap(translateAggregate) - val translatedGroupBys = groupBy.flatMap(translateGroupBy) + val translatedAggregates = aggregates.flatMap(translate).asInstanceOf[Seq[AggregateFunc]] + val translatedGroupBys = groupBy.flatMap(translate) if (translatedAggregates.length != aggregates.length || translatedGroupBys.length != groupBy.length) { From 51399f09757146e6b87b97a6ca6f36510358c982 Mon Sep 17 00:00:00 2001 From: Jiaan Geng Date: Thu, 4 Aug 2022 08:25:40 +0800 Subject: [PATCH 2/4] Update code --- .../catalyst/util/V2ExpressionBuilder.scala | 98 ++++++++++--------- 1 file changed, 51 insertions(+), 47 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala index ddb4ecdfbd83d..6b209fe407a68 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala @@ -18,9 +18,9 @@ package org.apache.spark.sql.catalyst.util import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression +import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, AggregateFunction} import org.apache.spark.sql.connector.expressions.{Cast => V2Cast, Expression => V2Expression, Extract => V2Extract, FieldReference, GeneralScalarExpression, LiteralValue, UserDefinedScalarFunc} -import org.apache.spark.sql.connector.expressions.aggregate.{Avg, Count, CountStar, GeneralAggregateFunc, Max, Min, Sum, UserDefinedAggregateFunc} +import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, Avg, Count, CountStar, GeneralAggregateFunc, Max, Min, Sum, UserDefinedAggregateFunc} import org.apache.spark.sql.connector.expressions.filter.{AlwaysFalse, AlwaysTrue, And => V2And, Not => V2Not, Or => V2Or, Predicate => V2Predicate} import org.apache.spark.sql.execution.datasources.PushableExpression import org.apache.spark.sql.types.{BooleanType, IntegerType} @@ -94,51 +94,7 @@ class V2ExpressionBuilder(e: Expression, isPredicate: Boolean = false) { case Cast(child, dataType, _, true) => generateExpression(child).map(v => new V2Cast(v, dataType)) case AggregateExpression(aggregateFunction, _, isDistinct, None, _) => - aggregateFunction match { - case aggregate.Min(PushableExpression(expr)) => Some(new Min(expr)) - case aggregate.Max(PushableExpression(expr)) => Some(new Max(expr)) - case count: aggregate.Count if count.children.length == 1 => - count.children.head match { - // COUNT(any literal) is the same as COUNT(*) - case Literal(_, _) => Some(new CountStar()) - case PushableExpression(expr) => Some(new Count(expr, isDistinct)) - case _ => None - } - case aggregate.Sum(PushableExpression(expr), _) => Some(new Sum(expr, isDistinct)) - case aggregate.Average(PushableExpression(expr), _) => Some(new Avg(expr, isDistinct)) - case aggregate.VariancePop(PushableExpression(expr), _) => - Some(new GeneralAggregateFunc("VAR_POP", isDistinct, Array(expr))) - case aggregate.VarianceSamp(PushableExpression(expr), _) => - Some(new GeneralAggregateFunc("VAR_SAMP", isDistinct, Array(expr))) - case aggregate.StddevPop(PushableExpression(expr), _) => - Some(new GeneralAggregateFunc("STDDEV_POP", isDistinct, Array(expr))) - case aggregate.StddevSamp(PushableExpression(expr), _) => - Some(new GeneralAggregateFunc("STDDEV_SAMP", isDistinct, Array(expr))) - case aggregate.CovPopulation(PushableExpression(left), PushableExpression(right), _) => - Some(new GeneralAggregateFunc("COVAR_POP", isDistinct, Array(left, right))) - case aggregate.CovSample(PushableExpression(left), PushableExpression(right), _) => - Some(new GeneralAggregateFunc("COVAR_SAMP", isDistinct, Array(left, right))) - case aggregate.Corr(PushableExpression(left), PushableExpression(right), _) => - Some(new GeneralAggregateFunc("CORR", isDistinct, Array(left, right))) - case aggregate.RegrIntercept(PushableExpression(left), PushableExpression(right)) => - Some(new GeneralAggregateFunc("REGR_INTERCEPT", isDistinct, Array(left, right))) - case aggregate.RegrR2(PushableExpression(left), PushableExpression(right)) => - Some(new GeneralAggregateFunc("REGR_R2", isDistinct, Array(left, right))) - case aggregate.RegrSlope(PushableExpression(left), PushableExpression(right)) => - Some(new GeneralAggregateFunc("REGR_SLOPE", isDistinct, Array(left, right))) - case aggregate.RegrSXY(PushableExpression(left), PushableExpression(right)) => - Some(new GeneralAggregateFunc("REGR_SXY", isDistinct, Array(left, right))) - // TODO supports other aggregate functions - case aggregate.V2Aggregator(aggrFunc, children, _, _) => - val translatedExprs = children.flatMap(PushableExpression.unapply(_)) - if (translatedExprs.length == children.length) { - Some(new UserDefinedAggregateFunc(aggrFunc.name(), - aggrFunc.canonicalName(), isDistinct, translatedExprs.toArray[V2Expression])) - } else { - None - } - case _ => None - } + generateAggregateFunc(aggregateFunction, isDistinct) case Abs(child, true) => generateExpressionWithName("ABS", Seq(child)) case Coalesce(children) => generateExpressionWithName("COALESCE", children) case Greatest(children) => generateExpressionWithName("GREATEST", children) @@ -318,6 +274,54 @@ class V2ExpressionBuilder(e: Expression, isPredicate: Boolean = false) { case _ => None } + private def generateAggregateFunc( + aggregateFunction: AggregateFunction, + isDistinct: Boolean): Option[AggregateFunc] = aggregateFunction match { + case aggregate.Min(PushableExpression(expr)) => Some(new Min(expr)) + case aggregate.Max(PushableExpression(expr)) => Some(new Max(expr)) + case count: aggregate.Count if count.children.length == 1 => + count.children.head match { + // COUNT(any literal) is the same as COUNT(*) + case Literal(_, _) => Some(new CountStar()) + case PushableExpression(expr) => Some(new Count(expr, isDistinct)) + case _ => None + } + case aggregate.Sum(PushableExpression(expr), _) => Some(new Sum(expr, isDistinct)) + case aggregate.Average(PushableExpression(expr), _) => Some(new Avg(expr, isDistinct)) + case aggregate.VariancePop(PushableExpression(expr), _) => + Some(new GeneralAggregateFunc("VAR_POP", isDistinct, Array(expr))) + case aggregate.VarianceSamp(PushableExpression(expr), _) => + Some(new GeneralAggregateFunc("VAR_SAMP", isDistinct, Array(expr))) + case aggregate.StddevPop(PushableExpression(expr), _) => + Some(new GeneralAggregateFunc("STDDEV_POP", isDistinct, Array(expr))) + case aggregate.StddevSamp(PushableExpression(expr), _) => + Some(new GeneralAggregateFunc("STDDEV_SAMP", isDistinct, Array(expr))) + case aggregate.CovPopulation(PushableExpression(left), PushableExpression(right), _) => + Some(new GeneralAggregateFunc("COVAR_POP", isDistinct, Array(left, right))) + case aggregate.CovSample(PushableExpression(left), PushableExpression(right), _) => + Some(new GeneralAggregateFunc("COVAR_SAMP", isDistinct, Array(left, right))) + case aggregate.Corr(PushableExpression(left), PushableExpression(right), _) => + Some(new GeneralAggregateFunc("CORR", isDistinct, Array(left, right))) + case aggregate.RegrIntercept(PushableExpression(left), PushableExpression(right)) => + Some(new GeneralAggregateFunc("REGR_INTERCEPT", isDistinct, Array(left, right))) + case aggregate.RegrR2(PushableExpression(left), PushableExpression(right)) => + Some(new GeneralAggregateFunc("REGR_R2", isDistinct, Array(left, right))) + case aggregate.RegrSlope(PushableExpression(left), PushableExpression(right)) => + Some(new GeneralAggregateFunc("REGR_SLOPE", isDistinct, Array(left, right))) + case aggregate.RegrSXY(PushableExpression(left), PushableExpression(right)) => + Some(new GeneralAggregateFunc("REGR_SXY", isDistinct, Array(left, right))) + // TODO supports other aggregate functions + case aggregate.V2Aggregator(aggrFunc, children, _, _) => + val translatedExprs = children.flatMap(PushableExpression.unapply(_)) + if (translatedExprs.length == children.length) { + Some(new UserDefinedAggregateFunc(aggrFunc.name(), + aggrFunc.canonicalName(), isDistinct, translatedExprs.toArray[V2Expression])) + } else { + None + } + case _ => None + } + private def flipComparisonOperatorName(operatorName: String): String = { operatorName match { case ">" => "<" From 84365ecf6488c3722a7365f0382ce84733a37fbc Mon Sep 17 00:00:00 2001 From: Jiaan Geng Date: Thu, 4 Aug 2022 10:54:14 +0800 Subject: [PATCH 3/4] Update sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala Co-authored-by: Wenchen Fan --- .../apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala index 6b209fe407a68..875b31e12136e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala @@ -93,7 +93,7 @@ class V2ExpressionBuilder(e: Expression, isPredicate: Boolean = false) { } case Cast(child, dataType, _, true) => generateExpression(child).map(v => new V2Cast(v, dataType)) - case AggregateExpression(aggregateFunction, _, isDistinct, None, _) => + case AggregateExpression(aggregateFunction, Complete, isDistinct, None, _) => generateAggregateFunc(aggregateFunction, isDistinct) case Abs(child, true) => generateExpressionWithName("ABS", Seq(child)) case Coalesce(children) => generateExpressionWithName("COALESCE", children) From 328539d75e6526cb254e4beb2995baa49a8d7e80 Mon Sep 17 00:00:00 2001 From: Jiaan Geng Date: Thu, 4 Aug 2022 11:58:56 +0800 Subject: [PATCH 4/4] Update code --- .../apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala index 875b31e12136e..81039a697656e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/util/V2ExpressionBuilder.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst.util import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, AggregateFunction} +import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, AggregateFunction, Complete} import org.apache.spark.sql.connector.expressions.{Cast => V2Cast, Expression => V2Expression, Extract => V2Extract, FieldReference, GeneralScalarExpression, LiteralValue, UserDefinedScalarFunc} import org.apache.spark.sql.connector.expressions.aggregate.{AggregateFunc, Avg, Count, CountStar, GeneralAggregateFunc, Max, Min, Sum, UserDefinedAggregateFunc} import org.apache.spark.sql.connector.expressions.filter.{AlwaysFalse, AlwaysTrue, And => V2And, Not => V2Not, Or => V2Or, Predicate => V2Predicate}