@@ -69,24 +69,78 @@ def data2(self):
6969 .drop ('ks' )
7070
def test_simple(self):
    """Run the shared merge check on ``data1`` and ``data2`` as they are."""
    left, right = self.data1, self.data2
    self._test_merge(left, right)
73+
def test_left_group_empty(self):
    """Run the merge check with the left input restricted to even ``id`` rows.

    The right input is passed through unfiltered.
    NOTE(review): the test name implies some ids then exist only on the
    right side; that depends on the contents of ``data2`` -- confirm.
    """
    has_even_id = col("id") % 2 == 0
    filtered_left = self.data1.where(has_even_id)
    self._test_merge(filtered_left, self.data2)
77+
def test_right_group_empty(self):
    """Run the merge check with the right input restricted to even ``id`` rows.

    The left input is passed through unfiltered.
    NOTE(review): the test name implies some ids then exist only on the
    left side; that depends on the contents of ``data1`` -- confirm.
    """
    has_even_id = col("id") % 2 == 0
    filtered_right = self.data2.where(has_even_id)
    self._test_merge(self.data1, filtered_right)
81+
def test_different_schemas(self):
    """Run the merge check when the right input has an extra column ``v3``.

    ``v3`` is a literal string column added to ``data2``, so the output
    schema handed to the merge check is extended with ``v3 string``.
    """
    widened_right = self.data2.withColumn('v3', lit('a'))
    merged_schema = 'id long, k int, v int, v2 int, v3 string'
    self._test_merge(self.data1, widened_right, output_schema=merged_schema)
85+
def test_complex_group_by(self):
    """Cogroup on a computed key (``id % 2 == 0``) instead of a plain column.

    The two inputs share their ``k`` values but have different ``id``
    values. Both are grouped by the same parity expression on ``id``, and
    the UDF joins each pair of groups on ``k``.
    """
    left_pdf = pd.DataFrame.from_dict({
        'id': [1, 2, 3],
        'k': [5, 6, 7],
        'v': [9, 10, 11],
    })
    right_pdf = pd.DataFrame.from_dict({
        'id': [11, 12, 13],
        'k': [5, 6, 7],
        'v2': [90, 100, 110],
    })

    left_sdf = self.spark.createDataFrame(left_pdf)
    left_grouped = left_sdf.groupby(col('id') % 2 == 0)

    right_sdf = self.spark.createDataFrame(right_pdf)
    right_grouped = right_sdf.groupby(col('id') % 2 == 0)

    @pandas_udf('k long, v long, v2 long', PandasUDFType.COGROUPED_MAP)
    def merge_pandas(lhs, rhs):
        # Only the join key and one value column from each side are kept,
        # so ``id`` does not appear in the output.
        return pd.merge(lhs[['k', 'v']], rhs[['k', 'v2']], on=['k'])

    cogrouped = left_grouped.cogroup(right_grouped)
    result = cogrouped.apply(merge_pandas).sort(['k']).toPandas()

    expected = pd.DataFrame.from_dict({
        'k': [5, 6, 7],
        'v': [9, 10, 11],
        'v2': [90, 100, 110],
    })

    assert_frame_equal(expected, result, check_column_type=_check_column_type)
76124
77- @pandas_udf ('id long, k int, v int, v2 int' , PandasUDFType .COGROUPED_MAP )
78- def merge_pandas (left , right ):
79- return pd .merge (left , right , how = 'outer' , on = ['k' , 'id' ])
def _test_merge(self, left, right, output_schema='id long, k int, v int, v2 int'):
    """Check a cogrouped pandas UDF merge against a plain pandas merge.

    Both inputs are grouped by ``id`` and cogrouped, and a ``COGROUPED_MAP``
    pandas UDF merges each pair of groups on ``['id', 'k']``. The collected
    result must equal ``pd.merge`` applied to the two inputs after each is
    collected with ``toPandas()``.

    :param left: left-hand input; the code calls ``groupby`` and
        ``toPandas`` on it.
    :param right: right-hand input; the code calls ``groupby`` and
        ``toPandas`` on it.
    :param output_schema: schema string declared for the UDF's output.
    """

    @pandas_udf(output_schema, PandasUDFType.COGROUPED_MAP)
    def merge_pandas(l, r):
        return pd.merge(l, r, on=['id', 'k'])

    result = left \
        .groupby('id') \
        .cogroup(right.groupby('id')) \
        .apply(merge_pandas) \
        .sort(['id', 'k']) \
        .toPandas()

    # Collect into new names instead of rebinding the parameters.
    left_pdf = left.toPandas()
    right_pdf = right.toPandas()

    # sort_values keeps the pre-sort index labels. Reset them so the
    # comparison below does not fail on index order alone when the merge
    # output was not already sorted.
    expected = pd \
        .merge(left_pdf, right_pdf, on=['id', 'k']) \
        .sort_values(by=['id', 'k']) \
        .reset_index(drop=True)

    assert_frame_equal(expected, result, check_column_type=_check_column_type)
92146
0 commit comments