11"""Unit tests for pandas helper."""
22
33# standard library
4- from dataclasses import dataclass
4+ from dataclasses import dataclass , astuple
55from typing import Any , Dict , Iterable
66import unittest
77
8- # from flask.testing import FlaskClient
8+ import pandas as pd
9+ from pandas .core .frame import DataFrame
10+ from sqlalchemy import text
911import mysql .connector
10- from delphi_utils import Nans
1112
13+ # from flask.testing import FlaskClient
14+ from delphi_utils import Nans
1215from delphi .epidata .server ._pandas import as_pandas
16+ from delphi .epidata .server ._query import limit_query
1317
1418# py3tester coverage target
1519__test_target__ = "delphi.epidata.server._query"
@@ -26,14 +30,14 @@ class CovidcastRow:
2630 geo_value : str = "01234"
2731 value_updated_timestamp : int = 20200202
2832 value : float = 10.0
29- stderr : float = 0
30- sample_size : float = 10
33+ stderr : float = 0.
34+ sample_size : float = 10.
3135 direction_updated_timestamp : int = 20200202
3236 direction : int = 0
3337 issue : int = 20200202
3438 lag : int = 0
3539 is_latest_issue : bool = True
36- is_wip : bool = False
40+ is_wip : bool = True
3741 missing_value : int = Nans .NOT_MISSING
3842 missing_stderr : int = Nans .NOT_MISSING
3943 missing_sample_size : int = Nans .NOT_MISSING
@@ -93,6 +97,14 @@ def geo_pair(self):
9397 def time_pair (self ):
9498 return f"{ self .time_type } :{ self .time_value } "
9599
100+ @property
101+ def astuple (self ):
102+ return astuple (self )[1 :]
103+
104+ @property
105+ def aslist (self ):
106+ return list (self .astuple )
107+
96108
97109class UnitTests (unittest .TestCase ):
98110 """Basic unit tests."""
@@ -134,12 +146,31 @@ def _insert_rows(self, rows: Iterable[CovidcastRow]):
134146 self .cnx .commit ()
135147 return rows
136148
149+ def _rows_to_df (self , rows : Iterable [CovidcastRow ]) -> pd .DataFrame :
150+ columns = [
151+ 'id' , 'source' , 'signal' , 'time_type' , 'geo_type' , 'time_value' ,
152+ 'geo_value' , 'value_updated_timestamp' , 'value' , 'stderr' ,
153+ 'sample_size' , 'direction_updated_timestamp' , 'direction' , 'issue' ,
154+ 'lag' , 'is_latest_issue' , 'is_wip' , 'missing_value' , 'missing_stderr' ,
155+ 'missing_sample_size'
156+ ]
157+ return pd .DataFrame .from_records ([[i ] + row .aslist for i , row in enumerate (rows , start = 1 )], columns = columns )
158+
137159 def test_as_pandas (self ):
138- rows = [CovidcastRow (time_value = 20200401 + i , value = i ) for i in range (10 )]
160+ rows = [CovidcastRow (time_value = 20200401 + i , value = float ( i ) ) for i in range (10 )]
139161 self ._insert_rows (rows )
140162
141163 with self .subTest ("simple" ):
142- query = "select * from `covidcast`"
143- out = as_pandas (query , limit_rows = 5 )
144- self .assertEqual (len (out ["epidata" ]), 5 )
145-
164+ query = """select * from `covidcast`"""
165+ params = {}
166+ parse_dates = None
167+ engine = self .cnx
168+ df = pd .read_sql_query (str (query ), engine , params = params , parse_dates = parse_dates )
169+ df = df .astype ({"is_latest_issue" : bool , "is_wip" : bool })
170+ expected_df = self ._rows_to_df (rows )
171+ pd .testing .assert_frame_equal (df , expected_df )
172+ query = limit_query (query , 5 )
173+ df = pd .read_sql_query (str (query ), engine , params = params , parse_dates = parse_dates )
174+ df = df .astype ({"is_latest_issue" : bool , "is_wip" : bool })
175+ expected_df = self ._rows_to_df (rows [:5 ])
176+ pd .testing .assert_frame_equal (df , expected_df )
0 commit comments