Skip to content

Commit 5bd8408

Browse files
authored
Merge pull request #1116 from cmu-delphi/ds/csv_importer_pandas
Refactor csv_importer.py with Pandas
2 parents 04eaada + 175c900 commit 5bd8408

File tree

2 files changed

+267
-344
lines changed

2 files changed

+267
-344
lines changed

src/acquisition/covidcast/csv_importer.py

Lines changed: 106 additions & 185 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,10 @@
1212
from datetime import date
1313
from glob import glob
1414
from logging import Logger
15-
from typing import Callable, Iterable, Iterator, NamedTuple, Optional, Tuple
15+
from typing import Callable, Iterable, List, NamedTuple, Optional, Tuple
1616

1717
import epiweeks as epi
18+
import numpy as np
1819
import pandas as pd
1920
from delphi.epidata.acquisition.covidcast.covidcast_row import CovidcastRow
2021
from delphi.epidata.acquisition.covidcast.database import Database, DBLoadStateException
@@ -61,6 +62,17 @@ def __init__(self, message, geo_id=None):
6162
self.message = message
6263
self.geo_id = geo_id
6364

65+
class GeoTypeSanityCheckException(ValueError):
66+
67+
def __init__(self, message, geo_type=None):
68+
self.message = message
69+
self.geo_type = geo_type
70+
71+
class ValueSanityCheckException(ValueError):
72+
73+
def __init__(self, message, value=None):
74+
self.message = message
75+
self.value = value
6476

6577

6678
class CsvImporter:
@@ -225,78 +237,7 @@ def is_header_valid(columns):
225237

226238

227239
@staticmethod
228-
def floaty_int(value: str) -> int:
229-
"""Cast a string to an int, even if it looks like a float.
230-
231-
For example, "-1" and "-1.0" should both result in -1. Non-integer floats
232-
will cause `ValueError` to be reaised.
233-
"""
234-
235-
float_value = float(value)
236-
if not float_value.is_integer():
237-
raise ValueError('not an int: "%s"' % str(value))
238-
return int(float_value)
239-
240-
241-
@staticmethod
242-
def maybe_apply(func, quantity):
243-
"""Apply the given function to the given quantity if not null-ish."""
244-
if str(quantity).lower() in ('inf', '-inf'):
245-
raise ValueError("Quantity given was an inf.")
246-
elif str(quantity).lower() in ('', 'na', 'nan', 'none'):
247-
return None
248-
else:
249-
return func(quantity)
250-
251-
252-
@staticmethod
253-
def validate_quantity(row, attr_quantity):
254-
"""Take a row and validate a given associated quantity (e.g., val, se, stderr).
255-
256-
Returns either a float, a None, or "Error".
257-
"""
258-
try:
259-
quantity = CsvImporter.maybe_apply(float, getattr(row, attr_quantity))
260-
return quantity
261-
except (ValueError, AttributeError):
262-
# val was a string or another data
263-
return "Error"
264-
265-
266-
@staticmethod
267-
def validate_missing_code(row, attr_quantity, attr_name, filepath=None, logger=None):
268-
"""Take a row and validate the missing code associated with
269-
a quantity (e.g., val, se, stderr).
270-
271-
Returns either a nan code for assignment to the missing quantity
272-
or a None to signal an error with the missing code. We decline
273-
to infer missing codes except for a very simple cases; the default
274-
is to produce an error so that the issue can be fixed in indicators.
275-
"""
276-
logger = get_structured_logger('load_csv') if logger is None else logger
277-
missing_entry = getattr(row, "missing_" + attr_name, None)
278-
279-
try:
280-
missing_entry = CsvImporter.floaty_int(missing_entry) # convert from string to float to int
281-
except (ValueError, TypeError):
282-
missing_entry = None
283-
284-
if missing_entry is None and attr_quantity is not None:
285-
return Nans.NOT_MISSING.value
286-
if missing_entry is None and attr_quantity is None:
287-
return Nans.OTHER.value
288-
289-
if missing_entry != Nans.NOT_MISSING.value and attr_quantity is not None:
290-
logger.warning(event = f"missing_{attr_name} column contradicting {attr_name} presence.", detail=(str(row)), file=filepath)
291-
return Nans.NOT_MISSING.value
292-
if missing_entry == Nans.NOT_MISSING.value and attr_quantity is None:
293-
logger.warning(event = f"missing_{attr_name} column contradicting {attr_name} presence.", detail=(str(row)), file=filepath)
294-
return Nans.OTHER.value
295-
296-
return missing_entry
297-
298-
@staticmethod
299-
def extract_and_check_row(geo_type, table):
240+
def extract_and_check_row(geo_type: str, table: pd.DataFrame, details: PathDetails) -> pd.DataFrame:
300241
"""Extract and return `CsvRowValue` from a CSV row, with sanity checks.
301242
302243
Also returns the name of the field which failed sanity check, or None.
@@ -305,56 +246,27 @@ def extract_and_check_row(geo_type, table):
305246
geo_type: the geographic resolution of the file
306247
"""
307248

308-
def check_county(geo_id):
309-
if len(geo_id) != 5 or not '01000' <= geo_id <= '80000':
310-
raise GeoIdSanityCheckException(f'len({geo_id}) != 5 or not "01000" <= {geo_id} <= "80000"', geo_id=geo_id)
311-
return geo_id
312-
313-
def check_hrr(geo_id):
314-
if not 1 <= int(geo_id) <= 500:
315-
raise GeoIdSanityCheckException(f'1 <= int({geo_id}) <= 500', geo_id=geo_id)
316-
return geo_id
317-
318-
def check_msa(geo_id):
319-
if len(geo_id) != 5 or not '10000' <= geo_id <= '99999':
320-
raise GeoIdSanityCheckException(f'len({geo_id}) != 5 or not "10000" <= {geo_id} <= "99999"', geo_id=geo_id)
321-
return geo_id
322-
323-
def check_dma(geo_id):
324-
if not 450 <= int(geo_id) <= 950:
325-
raise GeoIdSanityCheckException(f'not 450 <= int({geo_id}) <= 950', geo_id=geo_id)
326-
return geo_id
327-
328-
def check_state(geo_id):
329-
if len(geo_id) != 2 or not 'aa' <= geo_id <= 'zz':
330-
raise GeoIdSanityCheckException(f'len({geo_id}) != 2 or not "aa" <= {geo_id} <= "zz"', geo_id=geo_id)
331-
return geo_id
332-
333-
def check_hhs(geo_id):
334-
if not 1 <= int(geo_id) <= 10:
335-
raise GeoIdSanityCheckException(f'not 1 <= int({geo_id}) <= 10', geo_id=geo_id)
336-
return geo_id
337-
338-
def check_nation(geo_id):
339-
if len(geo_id) != 2 or not 'aa' <= geo_id <= 'zz':
340-
raise GeoIdSanityCheckException(f'len({geo_id}) != 2 or not "aa" <= {geo_id} <= "zz"', geo_id=geo_id)
341-
return geo_id
342-
343-
def validate_quantity(quantity):
344-
"""
345-
Take a row and validate a given associated quantity (e.g., val, se, stderr).
346-
Returns float, raise `GeoIdSanityCheckException`.
347-
"""
348-
if str(quantity).lower() in ('inf', '-inf'):
349-
raise GeoIdSanityCheckException("Quantity given was an inf.")
350-
elif str(quantity).lower() in ('', 'na', 'nan', 'none'):
351-
return None
352-
if quantity < 0:
353-
raise GeoIdSanityCheckException(f'{quantity} is not None and {quantity} < 0')
354-
else:
355-
return quantity
249+
def validate_geo_code(fail_mask: pd.Series, geo_type: str):
250+
validation_fails = table[fail_mask]
251+
if not validation_fails.empty:
252+
first_fail = validation_fails.iloc[0]
253+
raise GeoIdSanityCheckException(f'Invalid geo_id for {geo_type}', geo_id=first_fail["geo_id"])
254+
255+
def validate_quantity(column: pd.Series):
256+
"""Validate a column of a table using a validation function."""
257+
infinities = column[column.isin([float('inf'), float('-inf')])]
258+
if not infinities.empty:
259+
first_fail = infinities.iloc[0]
260+
raise ValueSanityCheckException(f'Invalid infinite value in {column.name}: {first_fail}', first_fail)
261+
262+
negative_values = column[column.lt(0)]
263+
if not negative_values.empty:
264+
first_fail = negative_values.iloc[0]
265+
raise ValueSanityCheckException(f'Invalid negative value in {column.name}: {first_fail}', first_fail)
266+
267+
return column
356268

357-
def validate_missing_code(data):
269+
def validate_missing_code(missing_code: pd.Series, column: pd.Series):
358270
"""Take a row and validate the missing code associated with
359271
a quantity (e.g., val, se, stderr).
360272
@@ -363,66 +275,63 @@ def validate_missing_code(data):
363275
to infer missing codes except for a very simple cases; the default
364276
is to produce an error so that the issue can be fixed in indicators.
365277
"""
366-
# logger = get_structured_logger('load_csv') if logger is None else logger
367-
368-
attr = data[0]
369-
missing_attr = data[1]
370-
# print(dir(data))Ц
371-
if missing_attr is pd.NA and attr is not pd.NA:
372-
return Nans.NOT_MISSING.value
373-
if missing_attr is pd.NA and attr is pd.NA:
374-
return Nans.OTHER.value
375-
if missing_attr != Nans.NOT_MISSING.value and attr is not pd.NA:
376-
# logger.warning(event = f"{data.index[1]} column contradicting {data.index[0]} presence.")
377-
return Nans.NOT_MISSING.value
378-
if missing_attr == Nans.NOT_MISSING.value and attr is pd.NA:
379-
# logger.warning(event = f"{data.index[1]} column contradicting {data.index[0]} presence.")
380-
return Nans.OTHER.value
381-
return missing_attr
278+
logger = get_structured_logger('validate_missing_code')
279+
280+
missing_code[missing_code.isna() & column.notna()] = Nans.NOT_MISSING.value
281+
missing_code[missing_code.isna() & column.isna()] = Nans.OTHER.value
282+
283+
contradict_mask = missing_code.ne(Nans.NOT_MISSING.value) & column.notna()
284+
if contradict_mask.any():
285+
first_fail = missing_code[contradict_mask].iloc[0]
286+
logger.warning(f'Correcting contradicting missing code: {first_fail} in {details.source}:{details.signal} {details.time_value} {details.geo_type}')
287+
missing_code[contradict_mask] = Nans.NOT_MISSING.value
288+
289+
contradict_mask = missing_code.eq(Nans.NOT_MISSING.value) & column.isna()
290+
if contradict_mask.any():
291+
first_fail = missing_code[contradict_mask].iloc[0]
292+
logger.warning(f'Correcting contradicting missing code: {first_fail} in {details.source}:{details.signal} {details.time_value} {details.geo_type}')
293+
missing_code[contradict_mask] = Nans.OTHER.value
294+
295+
return missing_code
382296

383297
# use consistent capitalization (e.g. for states)
384-
table['geo_id'] = table['geo_id'].map(lambda x: x.lower())
298+
table['geo_id'] = table['geo_id'].str.lower()
385299

386300
# sanity check geo_id with respect to geo_type
387301
if geo_type == 'county':
388-
table['geo_id'].apply(check_county)
389-
302+
fail_mask = (table['geo_id'].str.len() != 5) | ~table['geo_id'].between('01000', '80000')
390303
elif geo_type == 'hrr':
391-
table['geo_id'].apply(check_hrr)
392-
304+
fail_mask = ~table['geo_id'].astype(int).between(1, 500)
393305
elif geo_type == 'msa':
394-
table['geo_id'].apply(check_msa)
395-
306+
fail_mask = (table['geo_id'].str.len() != 5) | ~table['geo_id'].between('10000', '99999')
396307
elif geo_type == 'dma':
397-
table['geo_id'].apply(check_dma)
398-
308+
fail_mask = ~table['geo_id'].astype(int).between(450, 950)
399309
elif geo_type == 'state':
400-
# note that geo_id is lowercase
401-
table['geo_id'].apply(check_state)
402-
310+
fail_mask = (table['geo_id'].str.len() != 2) | ~table['geo_id'].between('aa', 'zz')
403311
elif geo_type == 'hhs':
404-
table['geo_id'].apply(check_hhs)
405-
312+
fail_mask = ~table['geo_id'].astype(int).between(1, 10)
406313
elif geo_type == 'nation':
407-
# geo_id is lowercase
408-
table['geo_id'].apply(check_nation)
314+
fail_mask = table['geo_id'] != 'us'
315+
else:
316+
raise GeoTypeSanityCheckException(f'Invalid geo_type: {geo_type}')
317+
318+
validate_geo_code(fail_mask, geo_type)
409319

410320
# Validate row values
411-
table['value'].apply(validate_quantity)
412-
table['stderr'].apply(validate_quantity)
413-
table['sample_size'].apply(validate_quantity)
321+
table['value'] = validate_quantity(table['value'])
322+
table['stderr'] = validate_quantity(table['stderr'])
323+
table['sample_size'] = validate_quantity(table['sample_size'])
414324

415-
# Validate and write missingness codes
416-
table['missing_value'] = table[['value', 'missing_value']].apply(validate_missing_code, axis=1)
417-
table['missing_stderr'] = table[['stderr', 'missing_stderr']].apply(validate_missing_code, axis=1)
418-
table['missing_sample_size'] = table[['sample_size', 'missing_sample_size']].apply(validate_missing_code, axis=1)
325+
# Validate and fix missingness codes
326+
table['missing_value'] = validate_missing_code(table['missing_value'], table['value'])
327+
table['missing_stderr'] = validate_missing_code(table['missing_stderr'], table['stderr'])
328+
table['missing_sample_size'] = validate_missing_code(table['missing_sample_size'], table['sample_size'])
419329

420-
# return validated table
421330
return table
422331

423332

424333
@staticmethod
425-
def load_csv(filepath: str, details: PathDetails) -> Iterator[Optional[CovidcastRow]]:
334+
def load_csv(filepath: str, details: PathDetails) -> Optional[List[CovidcastRow]]:
426335
"""Load, validate, and yield data as `RowValues` from a CSV file.
427336
428337
filepath: the CSV file to be loaded
@@ -435,44 +344,57 @@ def load_csv(filepath: str, details: PathDetails) -> Iterator[Optional[Covidcast
435344

436345
try:
437346
table = pd.read_csv(filepath, dtype=CsvImporter.DTYPES)
438-
except (ValueError, pd.errors.DtypeWarning) as e:
439-
logger.warning(event='Failed to open CSV with specified dtypes, switching to str', detail=str(e), file=filepath)
440-
table = pd.read_csv(filepath, dtype='str')
347+
except pd.errors.DtypeWarning as e:
348+
logger.warning(event='Failed to open CSV with specified dtypes', detail=str(e), file=filepath)
349+
return None
441350
except pd.errors.EmptyDataError as e:
442351
logger.warning(event='Empty data or header is encountered', detail=str(e), file=filepath)
443-
return
352+
return None
444353

445354
if not CsvImporter.is_header_valid(table.columns):
446355
logger.warning(event='invalid header', detail=table.columns, file=filepath)
447-
return
356+
return None
448357

449358
table.rename(columns={"val": "value", "se": "stderr", "missing_val": "missing_value", "missing_se": "missing_stderr"}, inplace=True)
359+
360+
for key in ["missing_value", "missing_stderr", "missing_sample_size"]:
361+
if key not in table.columns:
362+
table[key] = np.nan
450363

451364
try:
452-
table = CsvImporter.extract_and_check_row(details.geo_type, table)
365+
table = CsvImporter.extract_and_check_row(details.geo_type, table, details)
453366
except GeoIdSanityCheckException as err:
454-
row = table.loc[table['geo_id'] == err.geo_id]
455-
logger.warning(event='invalid value for row', detail=(row.to_csv(header=False, index=False, na_rep='NA')), file=filepath)
456-
return
457-
return[
367+
row = table.loc[table['geo_id'] == err.geo_id]
368+
logger.warning(event='invalid value for row', detail=(row.to_csv(header=False, index=False, na_rep='NA')), file=filepath)
369+
return None
370+
except GeoTypeSanityCheckException as err:
371+
logger.warning(event='invalid value for row', detail=err, file=filepath)
372+
return None
373+
except ValueSanityCheckException as err:
374+
row = table.loc[table['value'] == err.value]
375+
logger.warning(event='invalid value for row', detail=(row.to_csv(header=False, index=False, na_rep='NA')), file=filepath)
376+
return None
377+
except Exception as err:
378+
logger.warning(event='unknown error occured in extract_and_check_row', detail=err, file=filepath)
379+
return None
380+
return [
458381
CovidcastRow(
459382
source=details.source,
460383
signal=details.signal,
461384
time_type=details.time_type,
462385
geo_type=details.geo_type,
463386
time_value=details.time_value,
464387
geo_value=row.geo_id,
465-
value=row.value,
466-
stderr=row.stderr,
467-
sample_size=row.sample_size,
468-
missing_value = row.missing_value,
469-
missing_stderr = row.missing_stderr,
470-
missing_sample_size = row.missing_sample_size,
388+
value=row.value if pd.notna(row.value) else None,
389+
stderr=row.stderr if pd.notna(row.stderr) else None,
390+
sample_size=row.sample_size if pd.notna(row.sample_size) else None,
391+
missing_value=int(row.missing_value),
392+
missing_stderr=int(row.missing_stderr),
393+
missing_sample_size=int(row.missing_sample_size),
471394
issue=details.issue,
472395
lag=details.lag
473396
) for row in table.itertuples(index=False)
474397
]
475-
476398

477399

478400
def collect_files(data_dir: str, specific_issue_date: bool):
@@ -547,12 +469,11 @@ def upload_archive(
547469
archive_as_failed(path_src, filename, 'unknown',logger)
548470
continue
549471

472+
all_rows_valid = True
550473
csv_rows = CsvImporter.load_csv(path, details)
551-
rows_list = list(csv_rows)
552-
all_rows_valid = rows_list and all(r is not None for r in rows_list)
553-
if all_rows_valid:
474+
if csv_rows:
554475
try:
555-
modified_row_count = database.insert_or_update_bulk(rows_list)
476+
modified_row_count = database.insert_or_update_bulk(csv_rows)
556477
logger.info(f"insert_or_update_bulk {filename} returned {modified_row_count}")
557478
logger.info(
558479
"Inserted database rows",
@@ -577,7 +498,7 @@ def upload_archive(
577498
database.rollback()
578499

579500
# archive the current file based on validation results
580-
if all_rows_valid:
501+
if csv_rows and all_rows_valid:
581502
archive_as_successful(path_src, filename, details.source, logger)
582503
else:
583504
archive_as_failed(path_src, filename, details.source, logger)

0 commit comments

Comments
 (0)