@@ -4763,17 +4763,6 @@ def test_vectorized_udf_invalid_length(self):
                     'Result vector from pandas_udf was not the required length'):
                 df.select(raise_exception(col('id'))).collect()
 
-    def test_vectorized_udf_mix_udf(self):
-        from pyspark.sql.functions import pandas_udf, udf, col
-        df = self.spark.range(10)
-        row_by_row_udf = udf(lambda x: x, LongType())
-        pd_udf = pandas_udf(lambda x: x, LongType())
-        with QuietTest(self.sc):
-            with self.assertRaisesRegexp(
-                    Exception,
-                    'Can not mix vectorized and non-vectorized UDFs'):
-                df.select(row_by_row_udf(col('id')), pd_udf(col('id'))).collect()
-
     def test_vectorized_udf_chained(self):
         from pyspark.sql.functions import pandas_udf, col
         df = self.spark.range(10)
@@ -5060,6 +5049,166 @@ def test_type_annotation(self):
         df = self.spark.range(1).select(pandas_udf(f=_locals['noop'], returnType='bigint')('id'))
         self.assertEqual(df.first()[0], 0)
 
+    def test_mixed_udf(self):
+        import pandas as pd
+        from pyspark.sql.functions import col, udf, pandas_udf
+
+        df = self.spark.range(0, 1).toDF('v')
+
+        # Test mixture of multiple UDFs and Pandas UDFs.
+
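+        # The test removed above asserted that this combination raised
+        # 'Can not mix vectorized and non-vectorized UDFs'; the cases below
+        # exercise the now-supported mix.
+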
+        @udf('int')
+        def f1(x):
+            assert type(x) == int
+            return x + 1
+
+        @pandas_udf('int')
+        def f2(x):
+            assert type(x) == pd.Series
+            return x + 10
+
+        @udf('int')
+        def f3(x):
+            assert type(x) == int
+            return x + 100
+
+        @pandas_udf('int')
+        def f4(x):
+            assert type(x) == pd.Series
+            return x + 1000
+
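+        # Each UDF adds a distinct power of ten, so every expected total below
+        # identifies exactly which UDFs ran (e.g. +1111 means all four ran).
+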
+        # Test single expression with chained UDFs
+        df_chained_1 = df.withColumn('f2_f1', f2(f1(df['v'])))
+        df_chained_2 = df.withColumn('f3_f2_f1', f3(f2(f1(df['v']))))
+        df_chained_3 = df.withColumn('f4_f3_f2_f1', f4(f3(f2(f1(df['v'])))))
+        df_chained_4 = df.withColumn('f4_f2_f1', f4(f2(f1(df['v']))))
+        df_chained_5 = df.withColumn('f4_f3_f1', f4(f3(f1(df['v']))))
+
+        expected_chained_1 = df.withColumn('f2_f1', df['v'] + 11)
+        expected_chained_2 = df.withColumn('f3_f2_f1', df['v'] + 111)
+        expected_chained_3 = df.withColumn('f4_f3_f2_f1', df['v'] + 1111)
+        expected_chained_4 = df.withColumn('f4_f2_f1', df['v'] + 1011)
+        expected_chained_5 = df.withColumn('f4_f3_f1', df['v'] + 1101)
+
+        self.assertEquals(expected_chained_1.collect(), df_chained_1.collect())
+        self.assertEquals(expected_chained_2.collect(), df_chained_2.collect())
+        self.assertEquals(expected_chained_3.collect(), df_chained_3.collect())
+        self.assertEquals(expected_chained_4.collect(), df_chained_4.collect())
+        self.assertEquals(expected_chained_5.collect(), df_chained_5.collect())
+
+        # Test multiple mixed UDF expressions in a single projection
+        df_multi_1 = df \
+            .withColumn('f1', f1(col('v'))) \
+            .withColumn('f2', f2(col('v'))) \
+            .withColumn('f3', f3(col('v'))) \
+            .withColumn('f4', f4(col('v'))) \
+            .withColumn('f2_f1', f2(col('f1'))) \
+            .withColumn('f3_f1', f3(col('f1'))) \
+            .withColumn('f4_f1', f4(col('f1'))) \
+            .withColumn('f3_f2', f3(col('f2'))) \
+            .withColumn('f4_f2', f4(col('f2'))) \
+            .withColumn('f4_f3', f4(col('f3'))) \
+            .withColumn('f3_f2_f1', f3(col('f2_f1'))) \
+            .withColumn('f4_f2_f1', f4(col('f2_f1'))) \
+            .withColumn('f4_f3_f1', f4(col('f3_f1'))) \
+            .withColumn('f4_f3_f2', f4(col('f3_f2'))) \
+            .withColumn('f4_f3_f2_f1', f4(col('f3_f2_f1')))
+
+        # Test mixed udfs in a single expression
+        df_multi_2 = df \
+            .withColumn('f1', f1(col('v'))) \
+            .withColumn('f2', f2(col('v'))) \
+            .withColumn('f3', f3(col('v'))) \
+            .withColumn('f4', f4(col('v'))) \
+            .withColumn('f2_f1', f2(f1(col('v')))) \
+            .withColumn('f3_f1', f3(f1(col('v')))) \
+            .withColumn('f4_f1', f4(f1(col('v')))) \
+            .withColumn('f3_f2', f3(f2(col('v')))) \
+            .withColumn('f4_f2', f4(f2(col('v')))) \
+            .withColumn('f4_f3', f4(f3(col('v')))) \
+            .withColumn('f3_f2_f1', f3(f2(f1(col('v'))))) \
+            .withColumn('f4_f2_f1', f4(f2(f1(col('v'))))) \
+            .withColumn('f4_f3_f1', f4(f3(f1(col('v'))))) \
+            .withColumn('f4_f3_f2', f4(f3(f2(col('v'))))) \
+            .withColumn('f4_f3_f2_f1', f4(f3(f2(f1(col('v'))))))
+
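+        # Both construction styles should match plain column arithmetic.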
+        expected = df \
+            .withColumn('f1', df['v'] + 1) \
+            .withColumn('f2', df['v'] + 10) \
+            .withColumn('f3', df['v'] + 100) \
+            .withColumn('f4', df['v'] + 1000) \
+            .withColumn('f2_f1', df['v'] + 11) \
+            .withColumn('f3_f1', df['v'] + 101) \
+            .withColumn('f4_f1', df['v'] + 1001) \
+            .withColumn('f3_f2', df['v'] + 110) \
+            .withColumn('f4_f2', df['v'] + 1010) \
+            .withColumn('f4_f3', df['v'] + 1100) \
+            .withColumn('f3_f2_f1', df['v'] + 111) \
+            .withColumn('f4_f2_f1', df['v'] + 1011) \
+            .withColumn('f4_f3_f1', df['v'] + 1101) \
+            .withColumn('f4_f3_f2', df['v'] + 1110) \
+            .withColumn('f4_f3_f2_f1', df['v'] + 1111)
+
+        self.assertEquals(expected.collect(), df_multi_1.collect())
+        self.assertEquals(expected.collect(), df_multi_2.collect())
+
+    def test_mixed_udf_and_sql(self):
+        import pandas as pd
+        from pyspark.sql import Column
+        from pyspark.sql.functions import udf, pandas_udf
+
+        df = self.spark.range(0, 1).toDF('v')
+
+        # Test mixture of UDFs, Pandas UDFs and SQL expression.
+
+        @udf('int')
+        def f1(x):
+            assert type(x) == int
+            return x + 1
+
+        def f2(x):
+            assert type(x) == Column
+            return x + 10
+
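+        # Note that f2 is undecorated: it operates on Column and therefore
+        # compiles to a SQL expression rather than a Python UDF.
+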
+        @pandas_udf('int')
+        def f3(x):
+            assert type(x) == pd.Series
+            return x + 100
+
+        df1 = df.withColumn('f1', f1(df['v'])) \
+            .withColumn('f2', f2(df['v'])) \
+            .withColumn('f3', f3(df['v'])) \
+            .withColumn('f1_f2', f1(f2(df['v']))) \
+            .withColumn('f1_f3', f1(f3(df['v']))) \
+            .withColumn('f2_f1', f2(f1(df['v']))) \
+            .withColumn('f2_f3', f2(f3(df['v']))) \
+            .withColumn('f3_f1', f3(f1(df['v']))) \
+            .withColumn('f3_f2', f3(f2(df['v']))) \
+            .withColumn('f1_f2_f3', f1(f2(f3(df['v'])))) \
+            .withColumn('f1_f3_f2', f1(f3(f2(df['v'])))) \
+            .withColumn('f2_f1_f3', f2(f1(f3(df['v'])))) \
+            .withColumn('f2_f3_f1', f2(f3(f1(df['v'])))) \
+            .withColumn('f3_f1_f2', f3(f1(f2(df['v'])))) \
+            .withColumn('f3_f2_f1', f3(f2(f1(df['v']))))
+
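+        # Every function adds a constant, so composition order does not affect
+        # the expected total (e.g. f1_f2 and f2_f1 are both v + 11).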
+        expected = df.withColumn('f1', df['v'] + 1) \
+            .withColumn('f2', df['v'] + 10) \
+            .withColumn('f3', df['v'] + 100) \
+            .withColumn('f1_f2', df['v'] + 11) \
+            .withColumn('f1_f3', df['v'] + 101) \
+            .withColumn('f2_f1', df['v'] + 11) \
+            .withColumn('f2_f3', df['v'] + 110) \
+            .withColumn('f3_f1', df['v'] + 101) \
+            .withColumn('f3_f2', df['v'] + 110) \
+            .withColumn('f1_f2_f3', df['v'] + 111) \
+            .withColumn('f1_f3_f2', df['v'] + 111) \
+            .withColumn('f2_f1_f3', df['v'] + 111) \
+            .withColumn('f2_f3_f1', df['v'] + 111) \
+            .withColumn('f3_f1_f2', df['v'] + 111) \
+            .withColumn('f3_f2_f1', df['v'] + 111)
+
+        self.assertEquals(expected.collect(), df1.collect())
+
 
 @unittest.skipIf(
     not _have_pandas or not _have_pyarrow,
@@ -5487,6 +5636,21 @@ def dummy_pandas_udf(df):
                          F.col('temp0.key') == F.col('temp1.key'))
         self.assertEquals(res.count(), 5)
 
+    def test_mixed_scalar_udfs_followed_by_grouby_apply(self):
+        import pandas as pd
+        from pyspark.sql.functions import udf, pandas_udf, PandasUDFType
+
+        df = self.spark.range(0, 10).toDF('v1')
+        df = df.withColumn('v2', udf(lambda x: x + 1, 'int')(df['v1'])) \
+            .withColumn('v3', pandas_udf(lambda x: x + 2, 'int')(df['v1']))
+
+        result = df.groupby() \
+            .apply(pandas_udf(lambda x: pd.DataFrame([x.sum().sum()]),
+                              'sum int',
+                              PandasUDFType.GROUPED_MAP))
+
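+        # Total over all cells: sum(v1) + sum(v2) + sum(v3) = 45 + 55 + 65 = 165.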
+        self.assertEquals(result.collect()[0]['sum'], 165)
+
 
 @unittest.skipIf(
     not _have_pandas or not _have_pyarrow,