@@ -3395,9 +3395,16 @@ def assertFramesEqual(self, expected, result):
33953395 ("\n \n Result:\n %s\n %s" % (result , result .dtypes )))
33963396 self .assertTrue (expected .equals (result ), msg = msg )
33973397
3398- def test_groupby_apply (self ):
3398+ @property
3399+ def data (self ):
33993400 from pyspark .sql .functions import pandas_udf , array , explode , col , lit
3400- df = self .spark .range (10 ).toDF ('id' ).withColumn ("vs" , array ([lit (i ) for i in range (20 , 30 )])).withColumn ("v" , explode (col ('vs' ))).drop ('vs' )
3401+ return self .spark .range (10 ).toDF ('id' ) \
3402+ .withColumn ("vs" , array ([lit (i ) for i in range (20 , 30 )])) \
3403+ .withColumn ("v" , explode (col ('vs' ))).drop ('vs' )
3404+
3405+ def test_groupby_apply_simple (self ):
3406+ from pyspark .sql .functions import pandas_udf
3407+ df = self .data
34013408
34023409 def foo (df ):
34033410 ret = df
@@ -3417,6 +3424,26 @@ def foo(df):
34173424 expected = df .toPandas ().groupby ('id' ).apply (foo ).reset_index (drop = True )
34183425 self .assertFramesEqual (expected , result )
34193426
3427+ def test_groupby_apply_dtypes (self ):
3428+ from pyspark .sql .functions import pandas_udf
3429+ df = self .data
3430+
3431+ def foo (df ):
3432+ ret = df
3433+ ret = ret .assign (v3 = df .v * 5.0 + 1 )
3434+ return ret
3435+
3436+ sample_df = df .filter (df .id == 1 ).toPandas ()
3437+
3438+ foo_udf = pandas_udf (
3439+ foo ,
3440+ foo (sample_df ).dtypes
3441+ )
3442+
3443+ result = df .groupby ('id' ).apply (foo_udf ).sort ('id' ).toPandas ()
3444+ expected = df .toPandas ().groupby ('id' ).apply (foo ).reset_index (drop = True )
3445+ self .assertFramesEqual (expected , result )
3446+
34203447
34213448if __name__ == "__main__" :
34223449 from pyspark .sql .tests import *
0 commit comments