Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions python/pyspark/sql/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -5925,6 +5925,22 @@ def test_invalid_args(self):
'mixture.*aggregate function.*group aggregate pandas UDF'):
df.groupby(df.id).agg(mean_udf(df.v), mean(df.v)).collect()

def test_self_join_with_pandas(self):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just realized this test is in a wrong class. This should be moved to GroupedMapPandasUDFTests

import pyspark.sql.functions as F

@F.pandas_udf('key long, col string', F.PandasUDFType.GROUPED_MAP)
def dummy_pandas_udf(df):
return df[['key', 'col']]

df = self.spark.createDataFrame([Row(key=1, col='A'), Row(key=1, col='B'),
Row(key=2, col='C')])
dfWithPandas = df.groupBy('key').apply(dummy_pandas_udf)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: dfWithPandas -> df_with_pandas


# this was throwing an AnalysisException before SPARK-24208
res = dfWithPandas.alias('temp0').join(dfWithPandas.alias('temp1'),
F.col('temp0.key') == F.col('temp1.key'))
self.assertEquals(res.count(), 5)


@unittest.skipIf(
not _have_pandas or not _have_pyarrow,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -738,6 +738,10 @@ class Analyzer(
if findAliases(aggregateExpressions).intersect(conflictingAttributes).nonEmpty =>
(oldVersion, oldVersion.copy(aggregateExpressions = newAliases(aggregateExpressions)))

case oldVersion @ FlatMapGroupsInPandas(_, _, output, _)
if oldVersion.outputSet.intersect(conflictingAttributes).nonEmpty =>
(oldVersion, oldVersion.copy(output = output.map(_.newInstance())))

case oldVersion: Generate
if oldVersion.producedAttributes.intersect(conflictingAttributes).nonEmpty =>
val newOutput = oldVersion.generatorOutput.map(_.newInstance())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -93,4 +93,16 @@ class GroupedDatasetSuite extends QueryTest with SharedSQLContext {
}
datasetWithUDF.unpersist(true)
}

test("SPARK-24208: analysis fails on self-join with FlatMapGroupsInPandas") {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This test case should be rewritten and moved to AnalysisSuite

val df = datasetWithUDF.groupBy("s").flatMapGroupsInPandas(PythonUDF(
"pyUDF",
null,
StructType(Seq(StructField("s", LongType))),
Seq.empty,
PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF,
true))
val df1 = df.alias("temp0").join(df.alias("temp1"), $"temp0.s" === $"temp1.s")
df1.queryExecution.assertAnalyzed()
}
}