Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 17 additions & 1 deletion python/pyspark/sql/tests/test_pandas_cogrouped_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import sys

from pyspark.sql.functions import array, explode, col, lit, udf, sum, pandas_udf, PandasUDFType
from pyspark.sql.types import DoubleType, StructType, StructField
from pyspark.sql.types import DoubleType, StructType, StructField, Row
from pyspark.testing.sqlutils import ReusedSQLTestCase, have_pandas, have_pyarrow, \
pandas_requirement_message, pyarrow_requirement_message
from pyspark.testing.utils import QuietTest
Expand Down Expand Up @@ -193,6 +193,22 @@ def test_wrong_args(self):
left.groupby('id').cogroup(right.groupby('id')) \
.applyInPandas(lambda: 1, StructType([StructField("d", DoubleType())]))

def test_case_insensitive_grouping_column(self):
# SPARK-31915: case-insensitive grouping column should work.
df1 = self.spark.createDataFrame([(1, 1)], ("column", "value"))

row = df1.groupby("ColUmn").cogroup(
df1.groupby("COLUMN")
).applyInPandas(lambda r, l: r + l, "column long, value long").first()
self.assertEquals(row.asDict(), Row(column=2, value=2).asDict())

df2 = self.spark.createDataFrame([(1, 1)], ("column", "value"))

row = df1.groupby("ColUmn").cogroup(
df2.groupby("COLUMN")
).applyInPandas(lambda r, l: r + l, "column long, value long").first()
self.assertEquals(row.asDict(), Row(column=2, value=2).asDict())

@staticmethod
def _test_with_key(left, right, isLeft):

Expand Down
10 changes: 10 additions & 0 deletions python/pyspark/sql/tests/test_pandas_grouped_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -587,6 +587,16 @@ def f(key, pdf):
# Check that all group and window_range values from udf matched expected
self.assertTrue(all([r[0] for r in result]))

def test_case_insensitive_grouping_column(self):
# SPARK-31915: case-insensitive grouping column should work.
def my_pandas_udf(pdf):
return pdf.assign(score=0.5)

df = self.spark.createDataFrame([[1, 1]], ["column", "score"])
row = df.groupby('COLUMN').applyInPandas(
my_pandas_udf, schema="column integer, score float").first()
self.assertEquals(row.asDict(), Row(column=1, score=0.5).asDict())


if __name__ == "__main__":
from pyspark.sql.tests.test_pandas_grouped_map import *
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -546,9 +546,10 @@ class RelationalGroupedDataset protected[sql](
case ne: NamedExpression => ne
case other => Alias(other, other.toString)()
}
val groupingAttributes = groupingNamedExpressions.map(_.toAttribute)
val child = df.logicalPlan
val project = Project(groupingNamedExpressions ++ child.output, child)
val project = df.sparkSession.sessionState.executePlan(
Project(groupingNamedExpressions ++ child.output, child)).analyzed
val groupingAttributes = project.output.take(groupingNamedExpressions.length)
val output = expr.dataType.asInstanceOf[StructType].toAttributes
val plan = FlatMapGroupsInPandas(groupingAttributes, expr, output, project)

Expand Down Expand Up @@ -583,14 +584,16 @@ class RelationalGroupedDataset protected[sql](
case other => Alias(other, other.toString)()
}

val leftAttributes = leftGroupingNamedExpressions.map(_.toAttribute)
val rightAttributes = rightGroupingNamedExpressions.map(_.toAttribute)

val leftChild = df.logicalPlan
val rightChild = r.df.logicalPlan

val left = Project(leftGroupingNamedExpressions ++ leftChild.output, leftChild)
val right = Project(rightGroupingNamedExpressions ++ rightChild.output, rightChild)
val left = df.sparkSession.sessionState.executePlan(
Project(leftGroupingNamedExpressions ++ leftChild.output, leftChild)).analyzed
val right = r.df.sparkSession.sessionState.executePlan(
Project(rightGroupingNamedExpressions ++ rightChild.output, rightChild)).analyzed

val leftAttributes = left.output.take(leftGroupingNamedExpressions.length)
val rightAttributes = right.output.take(rightGroupingNamedExpressions.length)

val output = expr.dataType.asInstanceOf[StructType].toAttributes
val plan = FlatMapCoGroupsInPandas(leftAttributes, rightAttributes, expr, output, left, right)
Expand Down