57 changes: 35 additions & 22 deletions python/pyspark/sql/tests/test_pandas_grouped_map.py
@@ -545,13 +545,13 @@ def f(pdf):
 
     def test_grouped_over_window_with_key(self):
 
-        data = [(0, 1, "2018-03-10T00:00:00+00:00", False),
-                (1, 2, "2018-03-11T00:00:00+00:00", False),
-                (2, 2, "2018-03-12T00:00:00+00:00", False),
-                (3, 3, "2018-03-15T00:00:00+00:00", False),
-                (4, 3, "2018-03-16T00:00:00+00:00", False),
-                (5, 3, "2018-03-17T00:00:00+00:00", False),
-                (6, 3, "2018-03-21T00:00:00+00:00", False)]
+        data = [(0, 1, "2018-03-10T00:00:00+00:00", [0]),
+                (1, 2, "2018-03-11T00:00:00+00:00", [0]),
+                (2, 2, "2018-03-12T00:00:00+00:00", [0]),
+                (3, 3, "2018-03-15T00:00:00+00:00", [0]),
+                (4, 3, "2018-03-16T00:00:00+00:00", [0]),
+                (5, 3, "2018-03-17T00:00:00+00:00", [0]),
+                (6, 3, "2018-03-21T00:00:00+00:00", [0])]
 
         expected_window = [
             {'start': datetime.datetime(2018, 3, 10, 0, 0),
@@ -562,30 +562,43 @@ def test_grouped_over_window_with_key(self):
              'end': datetime.datetime(2018, 3, 25, 0, 0)},
         ]
 
-        expected = {0: (1, expected_window[0]),
-                    1: (2, expected_window[0]),
-                    2: (2, expected_window[0]),
-                    3: (3, expected_window[1]),
-                    4: (3, expected_window[1]),
-                    5: (3, expected_window[1]),
-                    6: (3, expected_window[2])}
+        expected_key = {0: (1, expected_window[0]),
+                        1: (2, expected_window[0]),
+                        2: (2, expected_window[0]),
+                        3: (3, expected_window[1]),
+                        4: (3, expected_window[1]),
+                        5: (3, expected_window[1]),
+                        6: (3, expected_window[2])}
+
+        # id -> array of group with len of num records in window
+        expected = {0: [1],
+                    1: [2, 2],
+                    2: [2, 2],
+                    3: [3, 3, 3],
+                    4: [3, 3, 3],
+                    5: [3, 3, 3],
+                    6: [3]}
 
         df = self.spark.createDataFrame(data, ['id', 'group', 'ts', 'result'])
         df = df.select(col('id'), col('group'), col('ts').cast('timestamp'), col('result'))
 
-        @pandas_udf(df.schema, PandasUDFType.GROUPED_MAP)
         def f(key, pdf):
             group = key[0]
             window_range = key[1]
-            # Result will be True if group and window range equal to expected
-            is_expected = pdf.id.apply(lambda id: (expected[id][0] == group and
-                                                   expected[id][1] == window_range))
-            return pdf.assign(result=is_expected)
+
+            # Make sure the key with group and window values are correct
+            for _, i in pdf.id.iteritems():
+                assert expected_key[i][0] == group, "{} != {}".format(expected_key[i][0], group)
+                assert expected_key[i][1] == window_range, \
+                    "{} != {}".format(expected_key[i][1], window_range)
 
-        result = df.groupby('group', window('ts', '5 days')).apply(f).select('result').collect()
+            return pdf.assign(result=[[group] * len(pdf)] * len(pdf))
 
-        # Check that all group and window_range values from udf matched expected
-        self.assertTrue(all([r[0] for r in result]))
+        result = df.groupby('group', window('ts', '5 days')).applyInPandas(f, df.schema)\
+            .select('id', 'result').collect()
+
+        for r in result:
+            self.assertListEqual(expected[r[0]], r[1])
 
     def test_case_insensitive_grouping_column(self):
         # SPARK-31915: case-insensitive grouping column should work.
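For readers skimming the patch, here is a minimal, self-contained sketch (not part of the patch) of the behavior the rewritten test exercises: when grouping by a column together with a time window, applyInPandas hands the UDF a key tuple whose first element is the grouping value and whose second is the window rendered as a dict with 'start' and 'end' datetimes, which is exactly what the asserts above check. The local session setup, the describe_key and desc names, and the DDL schema string are illustrative assumptions, not taken from the diff.

# Hedged sketch: assumes Spark 3.0+ (where applyInPandas exists), pyarrow
# installed, and a local session; describe_key/desc are illustrative names.
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, window

spark = SparkSession.builder.master("local[2]").getOrCreate()

df = (spark.createDataFrame(
          [(0, 1, "2018-03-10T00:00:00+00:00"),
           (1, 2, "2018-03-11T00:00:00+00:00"),
           (2, 2, "2018-03-12T00:00:00+00:00")],
          ['id', 'group', 'ts'])
      .withColumn('ts', col('ts').cast('timestamp')))

def describe_key(key, pdf):
    # key[0] is the 'group' value; key[1] is the window as a dict with
    # 'start' and 'end' datetimes, matching expected_key in the test above.
    group, window_range = key
    label = "group={} start={}".format(group, window_range['start'])
    return pdf.assign(desc=label)  # scalar broadcasts to every row

result = (df.groupby('group', window('ts', '5 days'))
          .applyInPandas(describe_key, "id long, group long, ts timestamp, desc string")
          .collect())

Each output row then carries its group value and window start, mirroring the list-valued result column that the updated test builds once per record in the window.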