|
1 | 1 | from datetime import date |
2 | 2 | from typing import ( |
| 3 | + Any, |
| 4 | + Dict, |
3 | 5 | Final, |
4 | 6 | List, |
5 | 7 | Mapping, |
|
9 | 11 | cast, |
10 | 12 | ) |
11 | 13 |
|
12 | | -from pandas import DataFrame |
| 14 | +from pandas import CategoricalDtype, DataFrame, Series |
13 | 15 | from requests import Response, Session |
14 | 16 | from requests.auth import HTTPBasicAuth |
15 | 17 | from tenacity import retry, stop_after_attempt |
|
21 | 23 | from ._model import ( |
22 | 24 | AEpiDataCall, |
23 | 25 | EpidataFieldInfo, |
| 26 | + EpidataFieldType, |
24 | 27 | EpiDataFormatType, |
25 | 28 | EpiDataResponse, |
26 | 29 | EpiRange, |
27 | 30 | EpiRangeParam, |
28 | 31 | OnlySupportsClassicFormatException, |
29 | 32 | add_endpoint_to_url, |
30 | 33 | ) |
| 34 | +from ._parse import fields_to_predicate |
31 | 35 |
|
32 | 36 | # Make the linter happy about the unused variables |
33 | 37 | __all__ = ["Epidata", "EpiDataCall", "EpiDataContext", "EpiRange", "CovidcastEpidata"] |
@@ -140,8 +144,36 @@ def df( |
140 | 144 | if self.only_supports_classic: |
141 | 145 | raise OnlySupportsClassicFormatException() |
142 | 146 | self._verify_parameters() |
143 | | - r = self.json(fields, disable_date_parsing=disable_date_parsing) |
144 | | - return self._as_df(r, fields, disable_date_parsing=disable_date_parsing) |
| 147 | + rows = self.json(fields, disable_date_parsing=disable_date_parsing) |
| 148 | + pred = fields_to_predicate(fields) |
| 149 | + columns: List[str] = [info.name for info in self.meta if pred(info.name)] |
| 150 | + df = DataFrame(rows, columns=columns or None) |
| 151 | + |
| 152 | + data_types: Dict[str, Any] = {} |
| 153 | + for info in self.meta: |
| 154 | + if not pred(info.name) or df[info.name].isnull().all(): |
| 155 | + continue |
| 156 | + if info.type == EpidataFieldType.bool: |
| 157 | + data_types[info.name] = bool |
| 158 | + elif info.type == EpidataFieldType.categorical: |
| 159 | + data_types[info.name] = CategoricalDtype( |
| 160 | + categories=Series(info.categories) if info.categories else None, ordered=True |
| 161 | + ) |
| 162 | + elif info.type == EpidataFieldType.int: |
| 163 | + data_types[info.name] = "Int64" |
| 164 | + elif info.type in ( |
| 165 | + EpidataFieldType.date, |
| 166 | + EpidataFieldType.epiweek, |
| 167 | + EpidataFieldType.date_or_epiweek, |
| 168 | + ): |
| 169 | + data_types[info.name] = "Int64" if disable_date_parsing else "datetime64[ns]" |
| 170 | + elif info.type == EpidataFieldType.float: |
| 171 | + data_types[info.name] = "Float64" |
| 172 | + else: |
| 173 | + data_types[info.name] = "string" |
| 174 | + if data_types: |
| 175 | + df = df.astype(data_types) |
| 176 | + return df |
145 | 177 |
|
146 | 178 |
|
147 | 179 | class EpiDataContext(AEpiDataEndpoints[EpiDataCall]): |
|
0 commit comments