From 5f325a4d4c420800c08293a135c935bb03c48c9e Mon Sep 17 00:00:00 2001
From: Marco Gaido
Date: Mon, 9 Jul 2018 17:11:18 +0200
Subject: [PATCH 1/4] [SPARK-24208][SQL] Fix attribute deduplication for
 FlatMapGroupsInPandas

---
 .../spark/sql/catalyst/analysis/Analyzer.scala      |  4 ++++
 .../org/apache/spark/sql/GroupedDatasetSuite.scala  | 12 ++++++++++++
 2 files changed, 16 insertions(+)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index e187133d03b17..9976b92fbc08f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -738,6 +738,10 @@ class Analyzer(
             if findAliases(aggregateExpressions).intersect(conflictingAttributes).nonEmpty =>
           (oldVersion, oldVersion.copy(aggregateExpressions = newAliases(aggregateExpressions)))
 
+        case oldVersion @ FlatMapGroupsInPandas(_, _, output, _)
+            if AttributeSet(output).intersect(conflictingAttributes).nonEmpty =>
+          (oldVersion, oldVersion.copy(output = output.map(_.newInstance())))
+
         case oldVersion: Generate
             if oldVersion.producedAttributes.intersect(conflictingAttributes).nonEmpty =>
           val newOutput = oldVersion.generatorOutput.map(_.newInstance())
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/GroupedDatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/GroupedDatasetSuite.scala
index 147c0b61f5017..bd54ea415ca88 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/GroupedDatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/GroupedDatasetSuite.scala
@@ -93,4 +93,16 @@ class GroupedDatasetSuite extends QueryTest with SharedSQLContext {
     }
     datasetWithUDF.unpersist(true)
   }
+
+  test("SPARK-24208: analysis fails on self-join with FlatMapGroupsInPandas") {
+    val df = datasetWithUDF.groupBy("s").flatMapGroupsInPandas(PythonUDF(
+      "pyUDF",
+      null,
+      StructType(Seq(StructField("s", LongType))),
+      Seq.empty,
+      PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF,
+      true))
+    val df1 = df.alias("temp0").join(df.alias("temp1"), $"temp0.s" === $"temp1.s")
+    df1.queryExecution.assertAnalyzed()
+  }
 }

From 032fef003e3f8ba86f05e5b7dd5844d948dda246 Mon Sep 17 00:00:00 2001
From: Marco Gaido
Date: Mon, 9 Jul 2018 18:15:34 +0200
Subject: [PATCH 2/4] add python test

---
 python/pyspark/sql/tests.py | 16 ++++++++++++++++
 1 file changed, 16 insertions(+)

diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 8d738069adb3d..6e11e45f5e894 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -5925,6 +5925,22 @@ def test_invalid_args(self):
                 'mixture.*aggregate function.*group aggregate pandas UDF'):
             df.groupby(df.id).agg(mean_udf(df.v), mean(df.v)).collect()
 
+    def test_self_join_with_pandas(self):
+        import pyspark.sql.functions as F
+
+        @F.pandas_udf('key long, col string', F.PandasUDFType.GROUPED_MAP)
+        def dummy_pandas_udf(df):
+            return df[['key','col']]
+
+        df = self.spark.createDataFrame([Row(key=1, col='A'), Row(key=1, col='B'),
+                                         Row(key=2, col='C')])
+        dfWithPandas = df.groupBy('key').apply(dummy_pandas_udf)
+
+        # this was throwing an AnalysisException before SPARK-24208
+        res = dfWithPandas.alias('temp0').join(dfWithPandas.alias('temp1'),
+                                               F.col('temp0.key') == F.col('temp1.key'))
+        self.assertEquals(res.count(), 5)
+
 
 @unittest.skipIf(
     not _have_pandas or not _have_pyarrow,
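The scenario exercised by the new Python test can also be reproduced standalone. The sketch below is not part of the patch: it assumes Spark 2.3+ with pandas and pyarrow installed, and the session setup and the name identity_udf are illustrative.

import pyspark.sql.functions as F
from pyspark.sql import Row, SparkSession

spark = SparkSession.builder.master('local[2]').getOrCreate()

# A grouped-map pandas UDF receives one pandas DataFrame per group and must
# return a pandas DataFrame matching the declared schema.
@F.pandas_udf('key long, col string', F.PandasUDFType.GROUPED_MAP)
def identity_udf(pdf):
    return pdf[['key', 'col']]

df = spark.createDataFrame([Row(key=1, col='A'), Row(key=1, col='B'),
                            Row(key=2, col='C')])
grouped = df.groupBy('key').apply(identity_udf)

# Before SPARK-24208 this self-join raised an AnalysisException, because both
# sides of the join exposed FlatMapGroupsInPandas output attributes with the
# same expression IDs.
joined = grouped.alias('temp0').join(grouped.alias('temp1'),
                                     F.col('temp0.key') == F.col('temp1.key'))
print(joined.count())

The expected count is 5: the two key=1 rows match each other pairwise (4 combinations) and the single key=2 row matches itself, which is what the test's assertEquals(res.count(), 5) checks.
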
From 11e9f0f2e4f84b215e969d5a319fd6e8fa070ca0 Mon Sep 17 00:00:00 2001
From: Marco Gaido
Date: Mon, 9 Jul 2018 19:37:26 +0200
Subject: [PATCH 3/4] fix python style

---
 python/pyspark/sql/tests.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 6e11e45f5e894..4404dbe40590a 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -5930,7 +5930,7 @@ def test_self_join_with_pandas(self):
 
         @F.pandas_udf('key long, col string', F.PandasUDFType.GROUPED_MAP)
         def dummy_pandas_udf(df):
-            return df[['key','col']]
+            return df[['key', 'col']]
 
         df = self.spark.createDataFrame([Row(key=1, col='A'), Row(key=1, col='B'),
                                          Row(key=2, col='C')])

From a15949ba0de5ab40d8fb27c155018996d1aed549 Mon Sep 17 00:00:00 2001
From: Marco Gaido
Date: Tue, 10 Jul 2018 10:24:52 +0200
Subject: [PATCH 4/4] address comment

---
 .../scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 9976b92fbc08f..c078efdfc0000 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -739,7 +739,7 @@ class Analyzer(
           (oldVersion, oldVersion.copy(aggregateExpressions = newAliases(aggregateExpressions)))
 
         case oldVersion @ FlatMapGroupsInPandas(_, _, output, _)
-            if AttributeSet(output).intersect(conflictingAttributes).nonEmpty =>
+            if oldVersion.outputSet.intersect(conflictingAttributes).nonEmpty =>
           (oldVersion, oldVersion.copy(output = output.map(_.newInstance())))
 
         case oldVersion: Generate
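With all four patches applied, the deduplication can be observed in the analyzed plan of such a self-join: the new Analyzer case rebuilds the output attributes of one FlatMapGroupsInPandas side via newInstance(), so the two sides carry distinct expression IDs. Continuing from the sketch above (the specific IDs printed vary from run to run):

# Each side of the join should now show a FlatMapGroupsInPandas node whose
# output attributes carry different expression IDs, e.g. key#3L on one side
# and key#42L on the other, instead of the conflicting duplicates that
# previously failed analysis.
joined.explain(extended=True)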