Skip to content

Commit 0f22008

Browse files
committed
Limit correlations: update tests
* wrote test to ensure row limits work
1 parent 43419b8 commit 0f22008

File tree

1 file changed

+42
-11
lines changed

1 file changed

+42
-11
lines changed

tests/server/test_pandas.py

Lines changed: 42 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,19 @@
11
"""Unit tests for pandas helper."""
22

33
# standard library
4-
from dataclasses import dataclass
4+
from dataclasses import dataclass, astuple
55
from typing import Any, Dict, Iterable
66
import unittest
77

8-
# from flask.testing import FlaskClient
8+
import pandas as pd
9+
from pandas.core.frame import DataFrame
10+
from sqlalchemy import text
911
import mysql.connector
10-
from delphi_utils import Nans
1112

13+
# from flask.testing import FlaskClient
14+
from delphi_utils import Nans
1215
from delphi.epidata.server._pandas import as_pandas
16+
from delphi.epidata.server._query import limit_query
1317

1418
# py3tester coverage target
1519
__test_target__ = "delphi.epidata.server._query"
@@ -26,14 +30,14 @@ class CovidcastRow:
2630
geo_value: str = "01234"
2731
value_updated_timestamp: int = 20200202
2832
value: float = 10.0
29-
stderr: float = 0
30-
sample_size: float = 10
33+
stderr: float = 0.
34+
sample_size: float = 10.
3135
direction_updated_timestamp: int = 20200202
3236
direction: int = 0
3337
issue: int = 20200202
3438
lag: int = 0
3539
is_latest_issue: bool = True
36-
is_wip: bool = False
40+
is_wip: bool = True
3741
missing_value: int = Nans.NOT_MISSING
3842
missing_stderr: int = Nans.NOT_MISSING
3943
missing_sample_size: int = Nans.NOT_MISSING
@@ -93,6 +97,14 @@ def geo_pair(self):
9397
def time_pair(self):
9498
return f"{self.time_type}:{self.time_value}"
9599

100+
@property
101+
def astuple(self):
102+
return astuple(self)[1:]
103+
104+
@property
105+
def aslist(self):
106+
return list(self.astuple)
107+
96108

97109
class UnitTests(unittest.TestCase):
98110
"""Basic unit tests."""
@@ -134,12 +146,31 @@ def _insert_rows(self, rows: Iterable[CovidcastRow]):
134146
self.cnx.commit()
135147
return rows
136148

149+
def _rows_to_df(self, rows: Iterable[CovidcastRow]) -> pd.DataFrame:
150+
columns = [
151+
'id', 'source', 'signal', 'time_type', 'geo_type', 'time_value',
152+
'geo_value', 'value_updated_timestamp', 'value', 'stderr',
153+
'sample_size', 'direction_updated_timestamp', 'direction', 'issue',
154+
'lag', 'is_latest_issue', 'is_wip', 'missing_value', 'missing_stderr',
155+
'missing_sample_size'
156+
]
157+
return pd.DataFrame.from_records([[i] + row.aslist for i, row in enumerate(rows, start=1)], columns=columns)
158+
137159
def test_as_pandas(self):
138-
rows = [CovidcastRow(time_value=20200401 + i, value=i) for i in range(10)]
160+
rows = [CovidcastRow(time_value=20200401 + i, value=float(i)) for i in range(10)]
139161
self._insert_rows(rows)
140162

141163
with self.subTest("simple"):
142-
query = "select * from `covidcast`"
143-
out = as_pandas(query, limit_rows=5)
144-
self.assertEqual(len(out["epidata"]), 5)
145-
164+
query = """select * from `covidcast`"""
165+
params = {}
166+
parse_dates = None
167+
engine = self.cnx
168+
df = pd.read_sql_query(str(query), engine, params=params, parse_dates=parse_dates)
169+
df = df.astype({"is_latest_issue": bool, "is_wip": bool})
170+
expected_df = self._rows_to_df(rows)
171+
pd.testing.assert_frame_equal(df, expected_df)
172+
query = limit_query(query, 5)
173+
df = pd.read_sql_query(str(query), engine, params=params, parse_dates=parse_dates)
174+
df = df.astype({"is_latest_issue": bool, "is_wip": bool})
175+
expected_df = self._rows_to_df(rows[:5])
176+
pd.testing.assert_frame_equal(df, expected_df)

0 commit comments

Comments
 (0)