Skip to content

Commit a9dac7f

Browse files
authored
Merge pull request #704 from cmu-delphi/dshemetov/limit_correlations
Limit rows returned for correlations
2 parents add5b19 + dec134a commit a9dac7f

File tree

3 files changed

+181
-5
lines changed

3 files changed

+181
-5
lines changed

src/server/_pandas.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,16 +2,19 @@
22
import pandas as pd
33

44
from sqlalchemy import text
5+
from sqlalchemy.engine.base import Engine
56

67
from ._common import engine
7-
from ._printer import create_printer, APrinter
8-
from ._query import filter_fields
8+
from ._config import MAX_RESULTS
9+
from ._printer import create_printer
10+
from ._query import filter_fields, limit_query
911
from ._exceptions import DatabaseErrorException
1012

1113

12-
def as_pandas(query: str, params: Dict[str, Any], db_engine: Engine = engine, parse_dates: Optional[Dict[str, str]] = None, limit_rows: int = MAX_RESULTS + 1) -> pd.DataFrame:
    """Run *query* with a row limit and return the result as a DataFrame.

    Args:
        query: SQL text; a ``LIMIT`` clause is appended to it via
            ``limit_query`` before execution.
        params: bind parameters forwarded to ``pd.read_sql_query``.
        db_engine: what the query is executed against; defaults to the
            module-level ``engine``.
        parse_dates: forwarded unchanged to ``pd.read_sql_query``.
        limit_rows: maximum number of rows to fetch; defaults to
            ``MAX_RESULTS + 1``.

    Returns:
        The DataFrame produced by ``pd.read_sql_query``.

    Raises:
        DatabaseErrorException: if building or executing the query fails
            for any reason; the original exception is attached as the cause.
    """
    try:
        query = limit_query(query, limit_rows)
        return pd.read_sql_query(text(str(query)), db_engine, params=params, parse_dates=parse_dates)
    except Exception as e:
        # chain the cause so the underlying traceback is not lost
        raise DatabaseErrorException(str(e)) from e
1720

src/server/_query.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -232,10 +232,15 @@ def parse_result(
232232
return [parse_row(row, fields_string, fields_int, fields_float) for row in db.execute(text(query), **params)]
233233

234234

235+
def limit_query(query: str, limit: int) -> str:
    """Return *query* with a ``LIMIT`` clause appended.

    Args:
        query: SQL text that does not already end in a ``LIMIT`` clause.
        limit: maximum number of rows; coerced with ``int()``.

    Returns:
        The string ``"<query> LIMIT <limit>"``.

    Raises:
        ValueError, TypeError: if *limit* cannot be converted to an integer.
    """
    # the limit is spliced into the SQL text, so force it to an integer first:
    # anything that is not a plain number is rejected instead of interpolated
    full_query = f"{query} LIMIT {int(limit)}"
    return full_query
238+
239+
235240
def run_query(p: APrinter, query_tuple: Tuple[str, Dict[str, Any]]):
    """Execute a ``(query, params)`` pair with a row limit and streamed results.

    The limit is ``p.remaining_rows + 1``: one row more than the printer can
    still take, which makes it possible to detect that more rows exist.
    Returns whatever ``db.execution_options(stream_results=True).execute``
    returns.
    """
    sql, bind_params = query_tuple
    # one extra row signals "there would have been more"
    row_cap = p.remaining_rows + 1
    full_query = text(limit_query(sql, row_cap))
    app.logger.info("full_query: %s, params: %s", full_query, bind_params)
    streaming_db = db.execution_options(stream_results=True)
    return streaming_db.execute(full_query, **bind_params)
241246

tests/server/test_pandas.py

Lines changed: 168 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,168 @@
1+
"""Unit tests for pandas helper."""
2+
3+
# standard library
4+
from dataclasses import dataclass, astuple
5+
from typing import Any, Dict, Iterable
6+
import unittest
7+
8+
import pandas as pd
9+
from sqlalchemy import create_engine
10+
11+
# from flask.testing import FlaskClient
12+
from delphi_utils import Nans
13+
from delphi.epidata.server.main import app
14+
from delphi.epidata.server._pandas import as_pandas
15+
from delphi.epidata.server._query import QueryBuilder
16+
17+
# py3tester coverage target
18+
__test_target__ = "delphi.epidata.server._query"
19+
20+
21+
@dataclass
class CovidcastRow:
    """A row of the ``covidcast`` table with ready-made default values.

    The field order matches the column order used when rows are inserted and
    when they are converted to tuples, so it must not be rearranged.
    """

    id: int = 0
    source: str = "src"
    signal: str = "sig"
    time_type: str = "day"
    geo_type: str = "county"
    time_value: int = 20200411
    geo_value: str = "01234"
    value_updated_timestamp: int = 20200202
    value: float = 10.0
    stderr: float = 0.0
    sample_size: float = 10.0
    direction_updated_timestamp: int = 20200202
    direction: int = 0
    issue: int = 20200202
    lag: int = 0
    is_latest_issue: bool = True
    is_wip: bool = True
    missing_value: int = Nans.NOT_MISSING
    missing_stderr: int = Nans.NOT_MISSING
    missing_sample_size: int = Nans.NOT_MISSING

    def __str__(self):
        """Render the row as a parenthesised, comma-separated value tuple.

        String-typed fields are wrapped in single quotes; every other field
        is formatted as-is. One value per line.
        """
        rendered = [
            f"{self.id}",
            f"'{self.source}'",
            f"'{self.signal}'",
            f"'{self.time_type}'",
            f"'{self.geo_type}'",
            f"{self.time_value}",
            f"'{self.geo_value}'",
            f"{self.value_updated_timestamp}",
            f"{self.value}",
            f"{self.stderr}",
            f"{self.sample_size}",
            f"{self.direction_updated_timestamp}",
            f"{self.direction}",
            f"{self.issue}",
            f"{self.lag}",
            f"{self.is_latest_issue}",
            f"{self.is_wip}",
            f"{self.missing_value}",
            f"{self.missing_stderr}",
            f"{self.missing_sample_size}",
        ]
        return "(\n" + ",\n".join(rendered) + "\n)"

    @staticmethod
    def from_json(json: Dict[str, Any]) -> "CovidcastRow":
        """Build a row from a mapping; unlisted fields keep their defaults.

        Raises ``KeyError`` if any of the copied keys is absent.
        """
        copied_keys = (
            "source",
            "signal",
            "time_type",
            "geo_type",
            "geo_value",
            "direction",
            "issue",
            "lag",
            "value",
            "stderr",
            "sample_size",
            "missing_value",
            "missing_stderr",
            "missing_sample_size",
        )
        return CovidcastRow(**{key: json[key] for key in copied_keys})

    @property
    def signal_pair(self):
        """``source:signal``."""
        return ":".join((f"{self.source}", f"{self.signal}"))

    @property
    def geo_pair(self):
        """``geo_type:geo_value``."""
        return ":".join((f"{self.geo_type}", f"{self.geo_value}"))

    @property
    def time_pair(self):
        """``time_type:time_value``."""
        return ":".join((f"{self.time_type}", f"{self.time_value}"))

    @property
    def astuple(self):
        """All field values except the leading ``id``, as a tuple."""
        # inside the method body ``astuple`` is the dataclasses function
        # imported at module level, not this property
        values = astuple(self)
        return values[1:]

    @property
    def aslist(self):
        """All field values except the leading ``id``, as a list."""
        return [*self.astuple]
106+
107+
108+
class UnitTests(unittest.TestCase):
    """Basic unit tests."""

    def setUp(self):
        """Perform per-test setup."""
        app.config["TESTING"] = True
        app.config["WTF_CSRF_ENABLED"] = False
        app.config["DEBUG"] = False

        # connect to the `epidata` database and clear the `covidcast` table
        engine = create_engine('mysql://user:pass@delphi_database_epidata/epidata')
        self.addCleanup(engine.dispose)
        cnx = engine.connect()
        # registered before any statement runs: tearDown is skipped when
        # setUp raises, cleanups are not, so the connection is always released
        self.addCleanup(cnx.close)
        cnx.execute("truncate table covidcast")
        cnx.execute('update covidcast_meta_cache set timestamp = 0, epidata = ""')

        # make connection available to test cases
        self.cnx = cnx

    def tearDown(self):
        """Perform per-test teardown."""
        self.cnx.close()

    def _insert_rows(self, rows: Iterable[CovidcastRow]):
        """Insert *rows* into `covidcast` and return them as a list.

        The input is materialised first so that a generator argument is not
        handed back exhausted. Nothing is executed for an empty input, since
        an INSERT with an empty VALUES list is not valid SQL.
        """
        rows = list(rows)
        if not rows:
            return rows
        sql = ",\n".join(str(r) for r in rows)
        self.cnx.execute(
            f"""
            INSERT INTO
                `covidcast` (`id`, `source`, `signal`, `time_type`, `geo_type`,
                `time_value`, `geo_value`, `value_updated_timestamp`,
                `value`, `stderr`, `sample_size`, `direction_updated_timestamp`,
                `direction`, `issue`, `lag`, `is_latest_issue`, `is_wip`,`missing_value`,
                `missing_stderr`,`missing_sample_size`)
            VALUES
            {sql}
            """
        )
        return rows

    def _rows_to_df(self, rows: Iterable[CovidcastRow]) -> pd.DataFrame:
        """Build the DataFrame expected back from the database for *rows*.

        The `id` column is filled with 1, 2, 3, ... in row order rather than
        with each row's own `id`.
        NOTE(review): presumably mirrors ids assigned by the database after
        the table is truncated - confirm against the table definition.
        """
        columns = [
            'id', 'source', 'signal', 'time_type', 'geo_type', 'time_value',
            'geo_value', 'value_updated_timestamp', 'value', 'stderr',
            'sample_size', 'direction_updated_timestamp', 'direction', 'issue',
            'lag', 'is_latest_issue', 'is_wip', 'missing_value', 'missing_stderr',
            'missing_sample_size'
        ]
        return pd.DataFrame.from_records([[i] + row.aslist for i, row in enumerate(rows, start=1)], columns=columns)

    def test_as_pandas(self):
        """as_pandas returns all rows by default and honours `limit_rows`."""
        rows = [CovidcastRow(time_value=20200401 + i, value=float(i)) for i in range(10)]
        self._insert_rows(rows)

        with app.test_request_context('/correlation'):
            q = QueryBuilder("covidcast", "t")

            # NOTE(review): the comparisons below rely on rows coming back in
            # insertion order; no ordering is requested here - confirm that
            # QueryBuilder or the table guarantees it.
            df = as_pandas(str(q), params={}, db_engine=self.cnx, parse_dates=None).astype({"is_latest_issue": bool, "is_wip": bool})
            expected_df = self._rows_to_df(rows)
            pd.testing.assert_frame_equal(df, expected_df)
            df = as_pandas(str(q), params={}, db_engine=self.cnx, parse_dates=None, limit_rows=5).astype({"is_latest_issue": bool, "is_wip": bool})
            expected_df = self._rows_to_df(rows[:5])
            pd.testing.assert_frame_equal(df, expected_df)

0 commit comments

Comments
 (0)