Skip to content

Commit 2bf402b

Browse files
committed
[SPARK-24208][SQL][FOLLOWUP] Move test cases to proper locations
1 parent: ebf4bfb · commit: 2bf402b

File tree

2 files changed, with 34 additions and 16 deletions.

python/pyspark/sql/tests.py

Lines changed: 16 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -5471,6 +5471,22 @@ def foo(_):
54715471
self.assertEqual(r.a, 'hi')
54725472
self.assertEqual(r.b, 1)
54735473

5474+
def test_self_join_with_pandas(self):
5475+
import pyspark.sql.functions as F
5476+
5477+
@F.pandas_udf('key long, col string', F.PandasUDFType.GROUPED_MAP)
5478+
def dummy_pandas_udf(df):
5479+
return df[['key', 'col']]
5480+
5481+
df = self.spark.createDataFrame([Row(key=1, col='A'), Row(key=1, col='B'),
5482+
Row(key=2, col='C')])
5483+
df_with_pandas = df.groupBy('key').apply(dummy_pandas_udf)
5484+
5485+
# this was throwing an AnalysisException before SPARK-24208
5486+
res = df_with_pandas.alias('temp0').join(df_with_pandas.alias('temp1'),
5487+
F.col('temp0.key') == F.col('temp1.key'))
5488+
self.assertEquals(res.count(), 5)
5489+
54745490

54755491
@unittest.skipIf(
54765492
not _have_pandas or not _have_pyarrow,
@@ -5925,22 +5941,6 @@ def test_invalid_args(self):
59255941
'mixture.*aggregate function.*group aggregate pandas UDF'):
59265942
df.groupby(df.id).agg(mean_udf(df.v), mean(df.v)).collect()
59275943

5928-
def test_self_join_with_pandas(self):
5929-
import pyspark.sql.functions as F
5930-
5931-
@F.pandas_udf('key long, col string', F.PandasUDFType.GROUPED_MAP)
5932-
def dummy_pandas_udf(df):
5933-
return df[['key', 'col']]
5934-
5935-
df = self.spark.createDataFrame([Row(key=1, col='A'), Row(key=1, col='B'),
5936-
Row(key=2, col='C')])
5937-
dfWithPandas = df.groupBy('key').apply(dummy_pandas_udf)
5938-
5939-
# this was throwing an AnalysisException before SPARK-24208
5940-
res = dfWithPandas.alias('temp0').join(dfWithPandas.alias('temp1'),
5941-
F.col('temp0.key') == F.col('temp1.key'))
5942-
self.assertEquals(res.count(), 5)
5943-
59445944

59455945
@unittest.skipIf(
59465946
not _have_pandas or not _have_pyarrow,

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala

Lines changed: 18 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -21,6 +21,7 @@ import java.util.TimeZone
2121

2222
import org.scalatest.Matchers
2323

24+
import org.apache.spark.api.python.PythonEvalType
2425
import org.apache.spark.sql.catalyst.TableIdentifier
2526
import org.apache.spark.sql.catalyst.dsl.expressions._
2627
import org.apache.spark.sql.catalyst.dsl.plans._
@@ -557,4 +558,21 @@ class AnalysisSuite extends AnalysisTest with Matchers {
557558
SubqueryAlias("tbl", testRelation)))
558559
assertAnalysisError(barrier, Seq("cannot resolve '`tbl.b`'"))
559560
}
561+
562+
test("SPARK-24208: analysis fails on self-join with FlatMapGroupsInPandas") {
563+
val pythonUdf = PythonUDF("pyUDF", null,
564+
StructType(Seq(StructField("a", LongType))),
565+
Seq.empty,
566+
PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF,
567+
true)
568+
val output = pythonUdf.dataType.asInstanceOf[StructType].toAttributes
569+
val project = Project(Seq(UnresolvedAttribute("a")), testRelation)
570+
val flatMapGroupsInPandas = FlatMapGroupsInPandas(
571+
Seq(UnresolvedAttribute("a")), pythonUdf, output, project)
572+
val left = SubqueryAlias("temp0", flatMapGroupsInPandas)
573+
val right = SubqueryAlias("temp1", flatMapGroupsInPandas)
574+
val join = Join(left, right, Inner, None)
575+
assertAnalysisSuccess(
576+
Project(Seq(UnresolvedAttribute("temp0.a"), UnresolvedAttribute("temp1.a")), join))
577+
}
560578
}

0 commit comments

Comments
 (0)