 package org.apache.spark.sql.execution.python

 import scala.collection.mutable
+import scala.collection.mutable.ArrayBuffer

-import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression}
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
+import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan, Project}
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.execution
 import org.apache.spark.sql.execution.SparkPlan

+
+/**
+ * Extracts all the Python UDFs in a logical Aggregate that depend on an aggregate
+ * expression or a grouping key, and evaluates them after the aggregate.
+ */
+private[spark] object ExtractPythonUDFFromAggregate extends Rule[LogicalPlan] {
+
+  /**
+   * Returns whether the expression can only be evaluated within the aggregate.
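+   * For example, given a hypothetical Aggregate with grouping key `a`, both `a` and
+   * `sum(b)` (an AggregateExpression) satisfy this predicate.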
+   */
+  private def belongAggregate(e: Expression, agg: Aggregate): Boolean = {
+    e.isInstanceOf[AggregateExpression] ||
+      agg.groupingExpressions.exists(_.semanticEquals(e))
+  }
+
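+  /**
+   * Returns whether `expr` contains a Python UDF whose inputs depend on the aggregate,
+   * e.g. a hypothetical pyUDF(sum(b)); a Python UDF over only literals would not match.
+   */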
+  private def hasPythonUdfOverAggregate(expr: Expression, agg: Aggregate): Boolean = {
+    expr.find {
+      e => e.isInstanceOf[PythonUDF] && e.find(belongAggregate(_, agg)).isDefined
+    }.isDefined
+  }
+
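+  /**
+   * Rewrites the Aggregate into a Project over a new Aggregate, so that every Python UDF
+   * over an aggregate result is evaluated in the Project. Schematically, with a
+   * hypothetical Python UDF pyUDF:
+   *   Aggregate [a], [a, pyUDF(sum(b))]
+   * becomes
+   *   Project [a, pyUDF(agg#0)]
+   *   +- Aggregate [a], [a, sum(b) AS agg#0]
+   */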
+  private def extract(agg: Aggregate): LogicalPlan = {
+    val projList = new ArrayBuffer[NamedExpression]()
+    val aggExpr = new ArrayBuffer[NamedExpression]()
+    agg.aggregateExpressions.foreach { expr =>
+      if (hasPythonUdfOverAggregate(expr, agg)) {
+        // Python UDFs can only be evaluated after the aggregate: replace every
+        // sub-expression that belongs to the aggregate with an attribute produced
+        // by the new Aggregate below.
+        val newE = expr transformDown {
+          case e: Expression if belongAggregate(e, agg) =>
+            val alias = e match {
+              case a: NamedExpression => a
+              case o => Alias(e, "agg")()
+            }
+            aggExpr += alias
+            alias.toAttribute
+        }
+        projList += newE.asInstanceOf[NamedExpression]
+      } else {
+        aggExpr += expr
+        projList += expr.toAttribute
+      }
+    }
+    // The new Aggregate contains no Python UDF over an aggregate expression or grouping key.
+    Project(projList, agg.copy(aggregateExpressions = aggExpr))
+  }
+
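+  /**
+   * Only rewrites the Aggregates that contain at least one Python UDF over an aggregate
+   * expression or grouping key; other plans are left untouched.
+   */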
+  def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
+    case agg: Aggregate if agg.aggregateExpressions.exists(hasPythonUdfOverAggregate(_, agg)) =>
+      extract(agg)
+  }
+}
+
+
 /**
  * Extracts PythonUDFs from operators, rewriting the query plan so that the UDF can be evaluated
  * alone in a batch.
@@ -59,10 +115,12 @@ private[spark] object ExtractPythonUDFs extends Rule[SparkPlan] {
   }

   /**
-   * Extract all the PythonUDFs from the current operator.
+   * Extract all the PythonUDFs from the current operator and evaluate them before the operator.
    */
-  def extract(plan: SparkPlan): SparkPlan = {
+  private def extract(plan: SparkPlan): SparkPlan = {
     val udfs = plan.expressions.flatMap(collectEvaluatableUDF)
+      // Ignore the PythonUDFs that come from the second/third aggregate, which are not used.
+      .filter(udf => udf.references.subsetOf(plan.inputSet))
     if (udfs.isEmpty) {
       // If there aren't any, we are done.
       plan
@@ -89,11 +147,7 @@ private[spark] object ExtractPythonUDFs extends Rule[SparkPlan] {
       // Other cases are disallowed as they are ambiguous or would require a cartesian
       // product.
       udfs.filterNot(attributeMap.contains).foreach { udf =>
-        if (udf.references.subsetOf(plan.inputSet)) {
-          sys.error(s"Invalid PythonUDF $udf, requires attributes from more than one child.")
-        } else {
-          sys.error(s"Unable to evaluate PythonUDF $udf. Missing input attributes.")
-        }
+        sys.error(s"Invalid PythonUDF $udf, requires attributes from more than one child.")
       }

       val rewritten = plan.transformExpressions {