diff --git a/python/pyspark/sql/tests/test_pandas_cogrouped_map.py b/python/pyspark/sql/tests/test_pandas_cogrouped_map.py
index 3ed9d2ac62fd..c1cb30c3caa9 100644
--- a/python/pyspark/sql/tests/test_pandas_cogrouped_map.py
+++ b/python/pyspark/sql/tests/test_pandas_cogrouped_map.py
@@ -19,7 +19,7 @@
 import sys
 
 from pyspark.sql.functions import array, explode, col, lit, udf, sum, pandas_udf, PandasUDFType
-from pyspark.sql.types import DoubleType, StructType, StructField
+from pyspark.sql.types import DoubleType, StructType, StructField, Row
 from pyspark.testing.sqlutils import ReusedSQLTestCase, have_pandas, have_pyarrow, \
     pandas_requirement_message, pyarrow_requirement_message
 from pyspark.testing.utils import QuietTest
@@ -193,6 +193,24 @@ def test_wrong_args(self):
             left.groupby('id').cogroup(right.groupby('id')) \
                 .applyInPandas(lambda: 1, StructType([StructField("d", DoubleType())]))
 
+    def test_case_insensitive_grouping_column(self):
+        # SPARK-31915: case-insensitive grouping column should work.
+        df1 = self.spark.createDataFrame([(1, 1)], ("column", "value"))
+
+        # Cogroup a DataFrame with itself, spelling the key with different casing on each side.
+        row = df1.groupby("ColUmn").cogroup(
+            df1.groupby("COLUMN")
+        ).applyInPandas(lambda r, l: r + l, "column long, value long").first()
+        self.assertEqual(row.asDict(), Row(column=2, value=2).asDict())
+
+        df2 = self.spark.createDataFrame([(1, 1)], ("column", "value"))
+
+        # Same check with the two sides coming from separately created DataFrames.
+        row = df1.groupby("ColUmn").cogroup(
+            df2.groupby("COLUMN")
+        ).applyInPandas(lambda r, l: r + l, "column long, value long").first()
+        self.assertEqual(row.asDict(), Row(column=2, value=2).asDict())
+
     @staticmethod
     def _test_with_key(left, right, isLeft):
 
diff --git a/python/pyspark/sql/tests/test_pandas_grouped_map.py b/python/pyspark/sql/tests/test_pandas_grouped_map.py
index ff53a0c6f2cf..76119432662b 100644
--- a/python/pyspark/sql/tests/test_pandas_grouped_map.py
+++ b/python/pyspark/sql/tests/test_pandas_grouped_map.py
@@ -587,6 +587,17 @@ def f(key, pdf):
         # Check that all group and window_range values from udf matched expected
         self.assertTrue(all([r[0] for r in result]))
 
+    def test_case_insensitive_grouping_column(self):
+        # SPARK-31915: case-insensitive grouping column should work.
+        def my_pandas_udf(pdf):
+            return pdf.assign(score=0.5)
+
+        # The grouping key is upper-cased while the DataFrame column is named "column".
+        df = self.spark.createDataFrame([[1, 1]], ["column", "score"])
+        row = df.groupby('COLUMN').applyInPandas(
+            my_pandas_udf, schema="column integer, score float").first()
+        self.assertEqual(row.asDict(), Row(column=1, score=0.5).asDict())
+
 
 if __name__ == "__main__":
     from pyspark.sql.tests.test_pandas_grouped_map import *
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala
index b1ba7d453873..c37d8eaa294b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala
@@ -546,9 +546,12 @@ class RelationalGroupedDataset protected[sql](
       case ne: NamedExpression => ne
       case other => Alias(other, other.toString)()
     }
-    val groupingAttributes = groupingNamedExpressions.map(_.toAttribute)
     val child = df.logicalPlan
-    val project = Project(groupingNamedExpressions ++ child.output, child)
+    // Analyze the projection first and read the grouping attributes back from its output,
+    // instead of deriving them from the grouping expressions via `toAttribute`.
+    val project = df.sparkSession.sessionState.executePlan(
+      Project(groupingNamedExpressions ++ child.output, child)).analyzed
+    val groupingAttributes = project.output.take(groupingNamedExpressions.length)
     val output = expr.dataType.asInstanceOf[StructType].toAttributes
     val plan = FlatMapGroupsInPandas(groupingAttributes, expr, output, project)
 
@@ -583,14 +586,18 @@ class RelationalGroupedDataset protected[sql](
       case other => Alias(other, other.toString)()
     }
 
-    val leftAttributes = leftGroupingNamedExpressions.map(_.toAttribute)
-    val rightAttributes = rightGroupingNamedExpressions.map(_.toAttribute)
-
     val leftChild = df.logicalPlan
     val rightChild = r.df.logicalPlan
 
-    val left = Project(leftGroupingNamedExpressions ++ leftChild.output, leftChild)
-    val right = Project(rightGroupingNamedExpressions ++ rightChild.output, rightChild)
+    // Each side is analyzed with the session of its own Dataset.
+    val left = df.sparkSession.sessionState.executePlan(
+      Project(leftGroupingNamedExpressions ++ leftChild.output, leftChild)).analyzed
+    val right = r.df.sparkSession.sessionState.executePlan(
+      Project(rightGroupingNamedExpressions ++ rightChild.output, rightChild)).analyzed
+
+    // The grouping columns are the leading entries of each analyzed projection's output.
+    val leftAttributes = left.output.take(leftGroupingNamedExpressions.length)
+    val rightAttributes = right.output.take(rightGroupingNamedExpressions.length)
 
     val output = expr.dataType.asInstanceOf[StructType].toAttributes
     val plan = FlatMapCoGroupsInPandas(leftAttributes, rightAttributes, expr, output, left, right)