Skip to content

Commit 87aeb92

Browse files
committed
more unit tests for cogroup
1 parent d4cf6d0 commit 87aeb92

File tree

3 files changed

+68
-15
lines changed

3 files changed

+68
-15
lines changed

python/pyspark/sql/tests/test_pandas_udf_cogrouped_map.py

Lines changed: 66 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -69,24 +69,78 @@ def data2(self):
6969
.drop('ks')
7070

7171
def test_simple(self):
72-
import pandas as pd
72+
self._test_merge(self.data1, self.data2)
73+
74+
def test_left_group_empty(self):
75+
left = self.data1.where(col("id") % 2 == 0)
76+
self._test_merge(left, self.data2)
77+
78+
def test_right_group_empty(self):
79+
right = self.data2.where(col("id") % 2 == 0)
80+
self._test_merge(self.data1, right)
81+
82+
def test_different_schemas(self):
83+
right = self.data2.withColumn('v3', lit('a'))
84+
self._test_merge(self.data1, right, output_schema='id long, k int, v int, v2 int, v3 string')
85+
86+
def test_complex_group_by(self):
87+
left = pd.DataFrame.from_dict({
88+
'id': [1, 2, 3],
89+
'k': [5, 6, 7],
90+
'v': [9, 10, 11]
91+
})
92+
93+
right = pd.DataFrame.from_dict({
94+
'id': [11, 12, 13],
95+
'k': [5, 6, 7],
96+
'v2': [90, 100, 110]
97+
})
98+
99+
left_df = self.spark\
100+
.createDataFrame(left)\
101+
.groupby(col('id') % 2 == 0)
102+
103+
right_df = self.spark \
104+
.createDataFrame(right) \
105+
.groupby(col('id') % 2 == 0)
106+
107+
@pandas_udf('k long, v long, v2 long', PandasUDFType.COGROUPED_MAP)
108+
def merge_pandas(l, r):
109+
return pd.merge(l[['k', 'v']], r[['k', 'v2']], on=['k'])
110+
111+
result = left_df \
112+
.cogroup(right_df) \
113+
.apply(merge_pandas) \
114+
.sort(['k']) \
115+
.toPandas()
116+
117+
expected = pd.DataFrame.from_dict({
118+
'k': [5, 6, 7],
119+
'v': [9, 10, 11],
120+
'v2': [90, 100, 110]
121+
})
73122

74-
l = self.data1
75-
r = self.data2
123+
assert_frame_equal(expected, result, check_column_type=_check_column_type)
76124

77-
@pandas_udf('id long, k int, v int, v2 int', PandasUDFType.COGROUPED_MAP)
78-
def merge_pandas(left, right):
79-
return pd.merge(left, right, how='outer', on=['k', 'id'])
125+
def _test_merge(self, left, right, output_schema='id long, k int, v int, v2 int'):
80126

81-
result = l\
82-
.groupby('id')\
83-
.cogroup(r.groupby(r.id))\
127+
@pandas_udf(output_schema, PandasUDFType.COGROUPED_MAP)
128+
def merge_pandas(l, r):
129+
return pd.merge(l, r, on=['id', 'k'])
130+
131+
result = left \
132+
.groupby('id') \
133+
.cogroup(right.groupby('id')) \
84134
.apply(merge_pandas)\
85-
.sort(['id', 'k'])\
135+
.sort(['id', 'k']) \
86136
.toPandas()
87137

88-
expected = pd\
89-
.merge(l.toPandas(), r.toPandas(), how='outer', on=['k', 'id'])
138+
left = left.toPandas()
139+
right = right.toPandas()
140+
141+
expected = pd \
142+
.merge(left, right, on=['id', 'k']) \
143+
.sort_values(by=['id', 'k'])
90144

91145
assert_frame_equal(expected, result, check_column_type=_check_column_type)
92146

sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapCoGroupsInPandasExec.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ import org.apache.spark.rdd.RDD
2222
import org.apache.spark.sql.catalyst.InternalRow
2323
import org.apache.spark.sql.catalyst.expressions._
2424
import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, Distribution, Partitioning}
25-
import org.apache.spark.sql.execution.{BinaryExecNode, CoGroupedIterator, GroupedIterator, SparkPlan}
25+
import org.apache.spark.sql.execution.{BinaryExecNode, CoGroupedIterator, SparkPlan}
2626

2727
case class FlatMapCoGroupsInPandasExec(
2828
leftGroup: Seq[Attribute],

sql/core/src/main/scala/org/apache/spark/sql/execution/python/InterleavedArrowPythonRunner.scala

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,15 +21,14 @@ import java.io._
2121
import java.net._
2222

2323
import org.apache.arrow.vector.VectorSchemaRoot
24-
import org.apache.arrow.vector.dictionary.DictionaryProvider
2524
import org.apache.arrow.vector.ipc.ArrowStreamWriter
25+
2626
import org.apache.spark._
2727
import org.apache.spark.api.python._
2828
import org.apache.spark.sql.catalyst.InternalRow
2929
import org.apache.spark.sql.execution.arrow.ArrowWriter
3030
import org.apache.spark.sql.types._
3131
import org.apache.spark.sql.util.ArrowUtils
32-
import org.apache.spark.sql.vectorized.ColumnarBatch
3332
import org.apache.spark.util.Utils
3433

3534

0 commit comments

Comments
 (0)