Skip to content

Commit 1138489

Browse files
mgaido91 authored and gatorsmile committed
[SPARK-24208][SQL][FOLLOWUP] Move test cases to proper locations
## What changes were proposed in this pull request?

This PR is a follow-up to SPARK-24208 that moves the test cases introduced by the original PR to their proper locations.

## How was this patch tested?

Moved unit tests (UTs).

Author: Marco Gaido <[email protected]>

Closes apache#21751 from mgaido91/SPARK-24208_followup.
1 parent 07704c9 commit 1138489

File tree

3 files changed

+34
-28
lines changed

3 files changed

+34
-28
lines changed

python/pyspark/sql/tests.py

Lines changed: 16 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -5471,6 +5471,22 @@ def foo(_):
54715471
self.assertEqual(r.a, 'hi')
54725472
self.assertEqual(r.b, 1)
54735473

5474+
def test_self_join_with_pandas(self):
5475+
import pyspark.sql.functions as F
5476+
5477+
@F.pandas_udf('key long, col string', F.PandasUDFType.GROUPED_MAP)
5478+
def dummy_pandas_udf(df):
5479+
return df[['key', 'col']]
5480+
5481+
df = self.spark.createDataFrame([Row(key=1, col='A'), Row(key=1, col='B'),
5482+
Row(key=2, col='C')])
5483+
df_with_pandas = df.groupBy('key').apply(dummy_pandas_udf)
5484+
5485+
# this was throwing an AnalysisException before SPARK-24208
5486+
res = df_with_pandas.alias('temp0').join(df_with_pandas.alias('temp1'),
5487+
F.col('temp0.key') == F.col('temp1.key'))
5488+
self.assertEquals(res.count(), 5)
5489+
54745490

54755491
@unittest.skipIf(
54765492
not _have_pandas or not _have_pyarrow,
@@ -5925,22 +5941,6 @@ def test_invalid_args(self):
59255941
'mixture.*aggregate function.*group aggregate pandas UDF'):
59265942
df.groupby(df.id).agg(mean_udf(df.v), mean(df.v)).collect()
59275943

5928-
def test_self_join_with_pandas(self):
5929-
import pyspark.sql.functions as F
5930-
5931-
@F.pandas_udf('key long, col string', F.PandasUDFType.GROUPED_MAP)
5932-
def dummy_pandas_udf(df):
5933-
return df[['key', 'col']]
5934-
5935-
df = self.spark.createDataFrame([Row(key=1, col='A'), Row(key=1, col='B'),
5936-
Row(key=2, col='C')])
5937-
dfWithPandas = df.groupBy('key').apply(dummy_pandas_udf)
5938-
5939-
# this was throwing an AnalysisException before SPARK-24208
5940-
res = dfWithPandas.alias('temp0').join(dfWithPandas.alias('temp1'),
5941-
F.col('temp0.key') == F.col('temp1.key'))
5942-
self.assertEquals(res.count(), 5)
5943-
59445944

59455945
@unittest.skipIf(
59465946
not _have_pandas or not _have_pyarrow,

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ import java.util.TimeZone
2121

2222
import org.scalatest.Matchers
2323

24+
import org.apache.spark.api.python.PythonEvalType
2425
import org.apache.spark.sql.catalyst.TableIdentifier
2526
import org.apache.spark.sql.catalyst.dsl.expressions._
2627
import org.apache.spark.sql.catalyst.dsl.plans._
@@ -557,4 +558,21 @@ class AnalysisSuite extends AnalysisTest with Matchers {
557558
SubqueryAlias("tbl", testRelation)))
558559
assertAnalysisError(barrier, Seq("cannot resolve '`tbl.b`'"))
559560
}
561+
562+
test("SPARK-24208: analysis fails on self-join with FlatMapGroupsInPandas") {
563+
val pythonUdf = PythonUDF("pyUDF", null,
564+
StructType(Seq(StructField("a", LongType))),
565+
Seq.empty,
566+
PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF,
567+
true)
568+
val output = pythonUdf.dataType.asInstanceOf[StructType].toAttributes
569+
val project = Project(Seq(UnresolvedAttribute("a")), testRelation)
570+
val flatMapGroupsInPandas = FlatMapGroupsInPandas(
571+
Seq(UnresolvedAttribute("a")), pythonUdf, output, project)
572+
val left = SubqueryAlias("temp0", flatMapGroupsInPandas)
573+
val right = SubqueryAlias("temp1", flatMapGroupsInPandas)
574+
val join = Join(left, right, Inner, None)
575+
assertAnalysisSuccess(
576+
Project(Seq(UnresolvedAttribute("temp0.a"), UnresolvedAttribute("temp1.a")), join))
577+
}
560578
}

sql/core/src/test/scala/org/apache/spark/sql/GroupedDatasetSuite.scala

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -93,16 +93,4 @@ class GroupedDatasetSuite extends QueryTest with SharedSQLContext {
9393
}
9494
datasetWithUDF.unpersist(true)
9595
}
96-
97-
test("SPARK-24208: analysis fails on self-join with FlatMapGroupsInPandas") {
98-
val df = datasetWithUDF.groupBy("s").flatMapGroupsInPandas(PythonUDF(
99-
"pyUDF",
100-
null,
101-
StructType(Seq(StructField("s", LongType))),
102-
Seq.empty,
103-
PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF,
104-
true))
105-
val df1 = df.alias("temp0").join(df.alias("temp1"), $"temp0.s" === $"temp1.s")
106-
df1.queryExecution.assertAnalyzed()
107-
}
10896
}

0 commit comments

Comments
 (0)