@@ -5471,6 +5471,22 @@ def foo(_):
         self.assertEqual(r.a, 'hi')
         self.assertEqual(r.b, 1)
 
+    def test_self_join_with_pandas(self):
+        import pyspark.sql.functions as F
+
+        @F.pandas_udf('key long, col string', F.PandasUDFType.GROUPED_MAP)
+        def dummy_pandas_udf(df):
+            return df[['key', 'col']]
+
+        df = self.spark.createDataFrame([Row(key=1, col='A'), Row(key=1, col='B'),
+                                         Row(key=2, col='C')])
+        df_with_pandas = df.groupBy('key').apply(dummy_pandas_udf)
+
+        # this was throwing an AnalysisException before SPARK-24208
+        res = df_with_pandas.alias('temp0').join(df_with_pandas.alias('temp1'),
+                                                 F.col('temp0.key') == F.col('temp1.key'))
+        self.assertEquals(res.count(), 5)
+
 
 @unittest.skipIf(
     not _have_pandas or not _have_pyarrow,
@@ -5925,22 +5941,6 @@ def test_invalid_args(self):
                     'mixture.*aggregate function.*group aggregate pandas UDF'):
                 df.groupby(df.id).agg(mean_udf(df.v), mean(df.v)).collect()
 
-    def test_self_join_with_pandas(self):
-        import pyspark.sql.functions as F
-
-        @F.pandas_udf('key long, col string', F.PandasUDFType.GROUPED_MAP)
-        def dummy_pandas_udf(df):
-            return df[['key', 'col']]
-
-        df = self.spark.createDataFrame([Row(key=1, col='A'), Row(key=1, col='B'),
-                                         Row(key=2, col='C')])
-        dfWithPandas = df.groupBy('key').apply(dummy_pandas_udf)
-
-        # this was throwing an AnalysisException before SPARK-24208
-        res = dfWithPandas.alias('temp0').join(dfWithPandas.alias('temp1'),
-                                               F.col('temp0.key') == F.col('temp1.key'))
-        self.assertEquals(res.count(), 5)
-
 
 @unittest.skipIf(
     not _have_pandas or not _have_pyarrow,
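For readers who want to try the moved test's scenario outside the test harness, a minimal standalone sketch follows. It mirrors the grouped-map pandas UDF self-join above; the locally built SparkSession, the print() call, and spark.stop() are illustrative additions rather than part of the patch, and the expected count of 5 comes from the test itself. The apply/PandasUDFType.GROUPED_MAP style shown is the one used by the test; newer Spark releases expose the same behavior through groupBy().applyInPandas.

# Standalone sketch of the scenario in test_self_join_with_pandas.
# Assumes a local SparkSession with pandas and pyarrow installed; session
# setup and the print() below are illustrative, not part of the diff above.
import pyspark.sql.functions as F
from pyspark.sql import Row, SparkSession

spark = SparkSession.builder.master('local[1]').getOrCreate()

@F.pandas_udf('key long, col string', F.PandasUDFType.GROUPED_MAP)
def dummy_pandas_udf(pdf):
    # Identity grouped-map UDF: each group arrives as a pandas DataFrame.
    return pdf[['key', 'col']]

df = spark.createDataFrame([Row(key=1, col='A'), Row(key=1, col='B'),
                            Row(key=2, col='C')])
df_with_pandas = df.groupBy('key').apply(dummy_pandas_udf)

# Self-join of the aliased grouped-map result; per the test's comment,
# this raised an AnalysisException before SPARK-24208.
res = df_with_pandas.alias('temp0').join(df_with_pandas.alias('temp1'),
                                         F.col('temp0.key') == F.col('temp1.key'))
print(res.count())  # 5: key=1 matches 2 x 2 rows, key=2 matches 1 x 1

spark.stop()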