Skip to content

Commit 00d06ca

Browse files
HyukjinKwonBryanCutler
authored andcommitted
[SPARK-31915][SQL][PYTHON] Resolve the grouping column properly per the case sensitivity in grouped and cogrouped pandas UDFs
### What changes were proposed in this pull request? This is another approach to fix the issue. See the previous try #28745. It was too invasive so I took more conservative approach. This PR proposes to resolve grouping attributes separately first so it can be properly referred when `FlatMapGroupsInPandas` and `FlatMapCoGroupsInPandas` are resolved without ambiguity. Previously, ```python from pyspark.sql.functions import * df = spark.createDataFrame([[1, 1]], ["column", "Score"]) pandas_udf("column integer, Score float", PandasUDFType.GROUPED_MAP) def my_pandas_udf(pdf): return pdf.assign(Score=0.5) df.groupby('COLUMN').apply(my_pandas_udf).show() ``` was failed as below: ``` pyspark.sql.utils.AnalysisException: "Reference 'COLUMN' is ambiguous, could be: COLUMN, COLUMN.;" ``` because the unresolved `COLUMN` in `FlatMapGroupsInPandas` doesn't know which reference to take from the child projection. After this fix, it resolves the child projection first with grouping keys and pass, to `FlatMapGroupsInPandas`, the attribute as a grouping key from the child projection that is positionally selected. ### Why are the changes needed? To resolve grouping keys correctly. ### Does this PR introduce _any_ user-facing change? Yes, ```python from pyspark.sql.functions import * df = spark.createDataFrame([[1, 1]], ["column", "Score"]) pandas_udf("column integer, Score float", PandasUDFType.GROUPED_MAP) def my_pandas_udf(pdf): return pdf.assign(Score=0.5) df.groupby('COLUMN').apply(my_pandas_udf).show() ``` ```python df1 = spark.createDataFrame([(1, 1)], ("column", "value")) df2 = spark.createDataFrame([(1, 1)], ("column", "value")) df1.groupby("COLUMN").cogroup( df2.groupby("COLUMN") ).applyInPandas(lambda r, l: r + l, df1.schema).show() ``` Before: ``` pyspark.sql.utils.AnalysisException: Reference 'COLUMN' is ambiguous, could be: COLUMN, COLUMN.; ``` ``` pyspark.sql.utils.AnalysisException: cannot resolve '`COLUMN`' given input columns: [COLUMN, COLUMN, value, value];; 'FlatMapCoGroupsInPandas ['COLUMN], ['COLUMN], <lambda>(column#9L, value#10L, column#13L, value#14L), [column#22L, value#23L] :- Project [COLUMN#9L, column#9L, value#10L] : +- LogicalRDD [column#9L, value#10L], false +- Project [COLUMN#13L, column#13L, value#14L] +- LogicalRDD [column#13L, value#14L], false ``` After: ``` +------+-----+ |column|Score| +------+-----+ | 1| 0.5| +------+-----+ ``` ``` +------+-----+ |column|value| +------+-----+ | 2| 2| +------+-----+ ``` ### How was this patch tested? Unittests were added and manually tested. Closes #28777 from HyukjinKwon/SPARK-31915-another. Authored-by: HyukjinKwon <[email protected]> Signed-off-by: Bryan Cutler <[email protected]>
1 parent 2ab82fa commit 00d06ca

File tree

3 files changed

+37
-8
lines changed

3 files changed

+37
-8
lines changed

python/pyspark/sql/tests/test_pandas_cogrouped_map.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
import sys
2020

2121
from pyspark.sql.functions import array, explode, col, lit, udf, sum, pandas_udf, PandasUDFType
22-
from pyspark.sql.types import DoubleType, StructType, StructField
22+
from pyspark.sql.types import DoubleType, StructType, StructField, Row
2323
from pyspark.testing.sqlutils import ReusedSQLTestCase, have_pandas, have_pyarrow, \
2424
pandas_requirement_message, pyarrow_requirement_message
2525
from pyspark.testing.utils import QuietTest
@@ -193,6 +193,22 @@ def test_wrong_args(self):
193193
left.groupby('id').cogroup(right.groupby('id')) \
194194
.applyInPandas(lambda: 1, StructType([StructField("d", DoubleType())]))
195195

196+
def test_case_insensitive_grouping_column(self):
197+
# SPARK-31915: case-insensitive grouping column should work.
198+
df1 = self.spark.createDataFrame([(1, 1)], ("column", "value"))
199+
200+
row = df1.groupby("ColUmn").cogroup(
201+
df1.groupby("COLUMN")
202+
).applyInPandas(lambda r, l: r + l, "column long, value long").first()
203+
self.assertEquals(row.asDict(), Row(column=2, value=2).asDict())
204+
205+
df2 = self.spark.createDataFrame([(1, 1)], ("column", "value"))
206+
207+
row = df1.groupby("ColUmn").cogroup(
208+
df2.groupby("COLUMN")
209+
).applyInPandas(lambda r, l: r + l, "column long, value long").first()
210+
self.assertEquals(row.asDict(), Row(column=2, value=2).asDict())
211+
196212
@staticmethod
197213
def _test_with_key(left, right, isLeft):
198214

python/pyspark/sql/tests/test_pandas_grouped_map.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -587,6 +587,16 @@ def f(key, pdf):
587587
# Check that all group and window_range values from udf matched expected
588588
self.assertTrue(all([r[0] for r in result]))
589589

590+
def test_case_insensitive_grouping_column(self):
591+
# SPARK-31915: case-insensitive grouping column should work.
592+
def my_pandas_udf(pdf):
593+
return pdf.assign(score=0.5)
594+
595+
df = self.spark.createDataFrame([[1, 1]], ["column", "score"])
596+
row = df.groupby('COLUMN').applyInPandas(
597+
my_pandas_udf, schema="column integer, score float").first()
598+
self.assertEquals(row.asDict(), Row(column=1, score=0.5).asDict())
599+
590600

591601
if __name__ == "__main__":
592602
from pyspark.sql.tests.test_pandas_grouped_map import *

sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -546,9 +546,10 @@ class RelationalGroupedDataset protected[sql](
546546
case ne: NamedExpression => ne
547547
case other => Alias(other, other.toString)()
548548
}
549-
val groupingAttributes = groupingNamedExpressions.map(_.toAttribute)
550549
val child = df.logicalPlan
551-
val project = Project(groupingNamedExpressions ++ child.output, child)
550+
val project = df.sparkSession.sessionState.executePlan(
551+
Project(groupingNamedExpressions ++ child.output, child)).analyzed
552+
val groupingAttributes = project.output.take(groupingNamedExpressions.length)
552553
val output = expr.dataType.asInstanceOf[StructType].toAttributes
553554
val plan = FlatMapGroupsInPandas(groupingAttributes, expr, output, project)
554555

@@ -583,14 +584,16 @@ class RelationalGroupedDataset protected[sql](
583584
case other => Alias(other, other.toString)()
584585
}
585586

586-
val leftAttributes = leftGroupingNamedExpressions.map(_.toAttribute)
587-
val rightAttributes = rightGroupingNamedExpressions.map(_.toAttribute)
588-
589587
val leftChild = df.logicalPlan
590588
val rightChild = r.df.logicalPlan
591589

592-
val left = Project(leftGroupingNamedExpressions ++ leftChild.output, leftChild)
593-
val right = Project(rightGroupingNamedExpressions ++ rightChild.output, rightChild)
590+
val left = df.sparkSession.sessionState.executePlan(
591+
Project(leftGroupingNamedExpressions ++ leftChild.output, leftChild)).analyzed
592+
val right = r.df.sparkSession.sessionState.executePlan(
593+
Project(rightGroupingNamedExpressions ++ rightChild.output, rightChild)).analyzed
594+
595+
val leftAttributes = left.output.take(leftGroupingNamedExpressions.length)
596+
val rightAttributes = right.output.take(rightGroupingNamedExpressions.length)
594597

595598
val output = expr.dataType.asInstanceOf[StructType].toAttributes
596599
val plan = FlatMapCoGroupsInPandas(leftAttributes, rightAttributes, expr, output, left, right)

0 commit comments

Comments
 (0)