Commit c35bc9a

committed (1 parent: 967736f)

Incorporate George's review feedback

* make _combine_source_signal_pairs logic clearer
* reorder covidcast.py endpoints under a single use_server_side_compute branch
* rename _resolve_all_signals to _resolve_bool_source_signals
* remove ugly one-liner from pad_time_pairs and replace with clearer code
* add a few documentation updates
* a few efficiency updates in smooth_diff

File tree: 5 files changed, +129 -117 lines

src/server/_params.py

Lines changed: 5 additions & 4 deletions
@@ -1,10 +1,11 @@
 from dataclasses import dataclass
-from itertools import groupby, chain
+from itertools import groupby
 from math import inf
 import re
 from typing import List, Optional, Sequence, Tuple, Union
 
 from flask import request
+from more_itertools import flatten
 
 from ._exceptions import ValidationFailedException
 from .utils import days_in_range, weeks_in_range, guess_time_value_is_day
@@ -105,9 +106,9 @@ def _combine_source_signal_pairs(source_signal_pairs: List[SourceSignalPair]) ->
     for source, group in source_signal_pairs_grouped:
         group = list(group)
         if any(x.signal == True for x in group):
-            source_signal_pairs_combined.append(SourceSignalPair(source, True))
-            continue
-        combined_signals = sorted(list(set(chain(*[x.signal for x in group]))))
+            combined_signals = True
+        else:
+            combined_signals = sorted(set(flatten(x.signal for x in group)))
         source_signal_pairs_combined.append(SourceSignalPair(source, combined_signals))
     return source_signal_pairs_combined
 
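
For context, a minimal self-contained sketch of the refactored combine step. The SourceSignalPair stub below models only the two fields the function touches, and the pre-sorting is assumed to match what source_signal_pairs_grouped provides in the real module:

from dataclasses import dataclass
from itertools import groupby
from typing import List, Union

from more_itertools import flatten

@dataclass
class SourceSignalPair:
    source: str
    signal: Union[bool, List[str]]  # True means "all signals for this source"

def combine_source_signal_pairs(pairs):
    combined = []
    for source, group in groupby(sorted(pairs, key=lambda p: p.source), lambda p: p.source):
        group = list(group)
        if any(p.signal == True for p in group):
            combined_signals = True  # one pair already requests every signal
        else:
            # union the explicit signal lists; flatten() replaces chain(*[...])
            combined_signals = sorted(set(flatten(p.signal for p in group)))
        combined.append(SourceSignalPair(source, combined_signals))
    return combined

pairs = [
    SourceSignalPair("fb-survey", True),
    SourceSignalPair("fb-survey", ["smoothed_cli"]),
    SourceSignalPair("jhu-csse", ["deaths"]),
    SourceSignalPair("jhu-csse", ["confirmed", "deaths"]),
]
print(combine_source_signal_pairs(pairs))
# [SourceSignalPair(source='fb-survey', signal=True),
#  SourceSignalPair(source='jhu-csse', signal=['confirmed', 'deaths'])]
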

src/server/endpoints/covidcast.py

Lines changed: 106 additions & 97 deletions
@@ -179,37 +179,8 @@ def handle():
     lag = extract_integer("lag")
     is_time_type_week = any(time_pair.time_type == "week" for time_pair in time_pairs)
     is_time_value_true = any(isinstance(time_pair.time_values, bool) for time_pair in time_pairs)
-    use_server_side_compute = not any((issues, lag, is_time_type_week, is_time_value_true)) and JIT_COMPUTE and not jit_bypass
-    if use_server_side_compute:
-        transform_args = parse_transform_args()
-        pad_length = get_pad_length(source_signal_pairs, transform_args.get("smoother_window_length"))
-        time_pairs = pad_time_pairs(time_pairs, pad_length)
-        source_signal_pairs, row_transform_generator = get_basename_signal_and_jit_generator(source_signal_pairs, transform_args=transform_args)
 
-    # build query
-    q = QueryBuilder(latest_table, "t")
-
-    fields_string = ["geo_type", "geo_value", "source", "signal", "time_type"]
-    fields_int = ["time_value", "direction", "issue", "lag", "missing_value", "missing_stderr", "missing_sample_size"]
-    fields_float = ["value", "stderr", "sample_size"]
     is_compatibility = is_compatibility_mode()
-
-    q.set_order("geo_type", "geo_value", "source", "signal", "time_type", "time_value", "issue")
-    q.set_fields(fields_string, fields_int, fields_float)
-
-    # basic query info
-    # data type of each field
-    # build the source, signal, time, and location (type and id) filters
-    q.where_source_signal_pairs("source", "signal", source_signal_pairs)
-    q.where_geo_pairs("geo_type", "geo_value", geo_pairs)
-    q.where_time_pairs("time_type", "time_value", time_pairs)
-
-    q.index = guess_index_to_use(time_pairs, geo_pairs, issues, lag, as_of)
-
-    _handle_lag_issues_as_of(q, issues, lag, as_of)
-
-    p = create_printer()
-
     def alias_row(row):
         if is_compatibility:
             # old api returned fewer fields
@@ -222,7 +193,20 @@ def alias_row(row):
             row["source"] = alias_mapper(row["source"], row["signal"])
         return row
 
+    # build query
+    q = QueryBuilder(latest_table, "t")
+
+    fields_string = ["geo_type", "geo_value", "source", "signal", "time_type"]
+    fields_int = ["time_value", "direction", "issue", "lag", "missing_value", "missing_stderr", "missing_sample_size"]
+    fields_float = ["value", "stderr", "sample_size"]
+
+    use_server_side_compute = not any((issues, lag, is_time_type_week, is_time_value_true)) and JIT_COMPUTE and not jit_bypass
     if use_server_side_compute:
+        transform_args = parse_transform_args()
+        pad_length = get_pad_length(source_signal_pairs, transform_args.get("smoother_window_length"))
+        time_pairs = pad_time_pairs(time_pairs, pad_length)
+        source_signal_pairs, row_transform_generator = get_basename_signal_and_jit_generator(source_signal_pairs, transform_args=transform_args)
+
         def gen_transform(rows):
             parsed_rows = (parse_row(row, fields_string, fields_int, fields_float) for row in rows)
             transformed_rows = row_transform_generator(parsed_rows=parsed_rows, time_pairs=time_pairs, transform_args=transform_args)
@@ -234,6 +218,22 @@ def gen_transform(rows):
             for row in parsed_rows:
                 yield alias_row(row)
 
+    q.set_order("geo_type", "geo_value", "source", "signal", "time_type", "time_value", "issue")
+    q.set_fields(fields_string, fields_int, fields_float)
+
+    # basic query info
+    # data type of each field
+    # build the source, signal, time, and location (type and id) filters
+    q.where_source_signal_pairs("source", "signal", source_signal_pairs)
+    q.where_geo_pairs("geo_type", "geo_value", geo_pairs)
+    q.where_time_pairs("time_type", "time_value", time_pairs)
+
+    q.index = guess_index_to_use(time_pairs, geo_pairs, issues, lag, as_of)
+
+    _handle_lag_issues_as_of(q, issues, lag, as_of)
+
+    p = create_printer()
+
     # execute first query
     try:
         r = run_query(p, (str(q), q.params))
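
After the reorder, handle() follows the shape every endpoint in this commit converges on: declare the field lists, compute use_server_side_compute once, bind gen_transform under that branch, and only then build and execute the query. A stripped-down sketch of the branch/bind pattern; double_value is a hypothetical stand-in for row_transform_generator, not the real epidata API:

from typing import Any, Callable, Dict, Iterable, Iterator

Row = Dict[str, Any]

# hypothetical stand-in for row_transform_generator
def double_value(rows: Iterable[Row]) -> Iterator[Row]:
    for row in rows:
        yield {**row, "value": row["value"] * 2}

def make_gen_transform(use_server_side_compute: bool) -> Callable[[Iterable[Row]], Iterator[Row]]:
    if use_server_side_compute:
        def gen_transform(rows: Iterable[Row]) -> Iterator[Row]:
            # JIT branch: parse and transform lazily, row by row
            yield from double_value(rows)
    else:
        def gen_transform(rows: Iterable[Row]) -> Iterator[Row]:
            # identity branch: stream rows through untouched
            yield from rows
    return gen_transform

rows = [{"time_value": 20200401, "value": 1.0}, {"time_value": 20200402, "value": 3.0}]
print(list(make_gen_transform(True)(rows)))   # values doubled
print(list(make_gen_transform(False)(rows)))  # passthrough
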
@@ -263,41 +263,40 @@ def handle_trend():
 
     time_window, is_day = parse_day_or_week_range_arg("window")
     time_value, is_also_day = parse_day_or_week_arg("date")
+
     if is_day != is_also_day:
         raise ValidationFailedException("mixing weeks with day arguments")
+
     _verify_argument_time_type_matches(is_day, daily_signals, weekly_signals)
+
     basis_time_value = extract_date("basis")
     if basis_time_value is None:
         base_shift = extract_integer("basis_shift")
         if base_shift is None:
             base_shift = 7
         basis_time_value = shift_time_value(time_value, -1 * base_shift) if is_day else shift_week_value(time_value, -1 * base_shift)
 
-    use_server_side_compute = not any((not is_day, not is_also_day)) and JIT_COMPUTE and not jit_bypass
-    if use_server_side_compute:
-        pad_length = get_pad_length(source_signal_pairs, transform_args.get("smoother_window_length"))
-        source_signal_pairs, row_transform_generator = get_basename_signal_and_jit_generator(source_signal_pairs)
-        time_window = pad_time_window(time_window, pad_length)
+    def gen_trend(rows):
+        for key, group in groupby(rows, lambda row: (row["geo_type"], row["geo_value"], row["source"], row["signal"])):
+            geo_type, geo_value, source, signal = key
+            if alias_mapper:
+                source = alias_mapper(source, signal)
+            trend = compute_trend(geo_type, geo_value, source, signal, time_value, basis_time_value, ((row["time_value"], row["value"]) for row in group))
+            yield trend.asdict()
 
     # build query
     q = QueryBuilder(latest_table, "t")
 
     fields_string = ["geo_type", "geo_value", "source", "signal"]
     fields_int = ["time_value"]
     fields_float = ["value"]
-    q.set_fields(fields_string, fields_int, fields_float)
-    q.set_order("geo_type", "geo_value", "source", "signal", "time_value")
-
-    q.where_source_signal_pairs("source", "signal", source_signal_pairs)
-    q.where_geo_pairs("geo_type", "geo_value", geo_pairs)
-    q.where_time_pairs("time_type", "time_value", [TimePair("day" if is_day else "week", [time_window])])
-
-    # fetch most recent issue fast
-    _handle_lag_issues_as_of(q, None, None, None)
-
-    p = create_printer()
 
+    use_server_side_compute = all((is_day, is_also_day)) and JIT_COMPUTE and not jit_bypass
     if use_server_side_compute:
+        pad_length = get_pad_length(source_signal_pairs, transform_args.get("smoother_window_length"))
+        source_signal_pairs, row_transform_generator = get_basename_signal_and_jit_generator(source_signal_pairs)
+        time_window = pad_time_window(time_window, pad_length)
+
        def gen_transform(rows):
             parsed_rows = (parse_row(row, fields_string, fields_int, fields_float) for row in rows)
             transformed_rows = row_transform_generator(parsed_rows=parsed_rows, time_pairs=[TimePair("day", [time_window])], transform_args=transform_args)
@@ -309,13 +308,18 @@ def gen_transform(rows):
             for row in parsed_rows:
                 yield row
 
-    def gen_trend(rows):
-        for key, group in groupby(rows, lambda row: (row["geo_type"], row["geo_value"], row["source"], row["signal"])):
-            geo_type, geo_value, source, signal = key
-            if alias_mapper:
-                source = alias_mapper(source, signal)
-            trend = compute_trend(geo_type, geo_value, source, signal, time_value, basis_time_value, ((row["time_value"], row["value"]) for row in group))
-            yield trend.asdict()
+    q.set_fields(fields_string, fields_int, fields_float)
+    q.set_order("geo_type", "geo_value", "source", "signal", "time_value")
+
+    q.where_source_signal_pairs("source", "signal", source_signal_pairs)
+    q.where_geo_pairs("geo_type", "geo_value", geo_pairs)
+    q.where_time_pairs("time_type", "time_value", [TimePair("day" if is_day else "week", [time_window])])
+
+    # fetch most recent issue fast
+    _handle_lag_issues_as_of(q, None, None, None)
+
+    p = create_printer()
+
 
     # execute first query
     try:
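
The gen_trend generator relies on a detail worth calling out: itertools.groupby merges only adjacent equal keys, so the query's ORDER BY on (geo_type, geo_value, source, signal) is what makes the per-series grouping correct. A sketch with compute_trend_stub standing in for the real compute_trend:

from itertools import groupby

# hypothetical stand-in for the real compute_trend helper
def compute_trend_stub(key, values):
    first, last = values[0], values[-1]
    return {"series": key, "change": last[1] - first[1]}

def gen_trend(rows):
    # rows must already be sorted by this key: groupby merges adjacent runs only
    keyfunc = lambda row: (row["geo_type"], row["geo_value"], row["source"], row["signal"])
    for key, group in groupby(rows, keyfunc):
        values = [(row["time_value"], row["value"]) for row in group]
        yield compute_trend_stub(key, values)

rows = [
    {"geo_type": "state", "geo_value": "pa", "source": "src", "signal": "sig", "time_value": 20200401, "value": 1.0},
    {"geo_type": "state", "geo_value": "pa", "source": "src", "signal": "sig", "time_value": 20200402, "value": 4.0},
    {"geo_type": "state", "geo_value": "tx", "source": "src", "signal": "sig", "time_value": 20200401, "value": 2.0},
]
print(list(gen_trend(rows)))
# [{'series': ('state', 'pa', 'src', 'sig'), 'change': 3.0},
#  {'series': ('state', 'tx', 'src', 'sig'), 'change': 0.0}]
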
@@ -338,40 +342,39 @@ def handle_trendseries():
     jit_bypass = parse_jit_bypass()
 
     time_window, is_day = parse_day_or_week_range_arg("window")
+
     _verify_argument_time_type_matches(is_day, daily_signals, weekly_signals)
+
     basis_shift = extract_integer(("basis", "basis_shift"))
     if basis_shift is None:
         basis_shift = 7
 
-    use_server_side_compute = is_day and JIT_COMPUTE and not jit_bypass
-    if use_server_side_compute:
-        pad_length = get_pad_length(source_signal_pairs, transform_args.get("smoother_window_length"))
-        source_signal_pairs, row_transform_generator = get_basename_signal_and_jit_generator(source_signal_pairs)
-        time_window = pad_time_window(time_window, pad_length)
+    shifter = lambda x: shift_time_value(x, -basis_shift)
+    if not is_day:
+        shifter = lambda x: shift_week_value(x, -basis_shift)
+
+    def gen_trend(rows):
+        for key, group in groupby(rows, lambda row: (row["geo_type"], row["geo_value"], row["source"], row["signal"])):
+            geo_type, geo_value, source, signal = key
+            if alias_mapper:
+                source = alias_mapper(source, signal)
+            trends = compute_trends(geo_type, geo_value, source, signal, shifter, ((row["time_value"], row["value"]) for row in group))
+            for t in trends:
+                yield t.asdict()
 
     # build query
     q = QueryBuilder(latest_table, "t")
 
     fields_string = ["geo_type", "geo_value", "source", "signal"]
     fields_int = ["time_value"]
     fields_float = ["value"]
-    q.set_fields(fields_string, fields_int, fields_float)
-    q.set_order("geo_type", "geo_value", "source", "signal", "time_value")
-
-    q.where_source_signal_pairs("source", "signal", source_signal_pairs)
-    q.where_geo_pairs("geo_type", "geo_value", geo_pairs)
-    q.where_time_pairs("time_type", "time_value", [TimePair("day" if is_day else "week", [time_window])])
-
-    # fetch most recent issue fast
-    _handle_lag_issues_as_of(q, None, None, None)
-
-    p = create_printer()
-
-    shifter = lambda x: shift_time_value(x, -basis_shift)
-    if not is_day:
-        shifter = lambda x: shift_week_value(x, -basis_shift)
 
+    use_server_side_compute = is_day and JIT_COMPUTE and not jit_bypass
     if use_server_side_compute:
+        pad_length = get_pad_length(source_signal_pairs, transform_args.get("smoother_window_length"))
+        source_signal_pairs, row_transform_generator = get_basename_signal_and_jit_generator(source_signal_pairs)
+        time_window = pad_time_window(time_window, pad_length)
+
         def gen_transform(rows):
             parsed_rows = (parse_row(row, fields_string, fields_int, fields_float) for row in rows)
             transformed_rows = row_transform_generator(parsed_rows=parsed_rows, time_pairs=[TimePair("day", [time_window])], transform_args=transform_args)
@@ -383,14 +386,17 @@ def gen_transform(rows):
             for row in parsed_rows:
                 yield row
 
-    def gen_trend(rows):
-        for key, group in groupby(rows, lambda row: (row["geo_type"], row["geo_value"], row["source"], row["signal"])):
-            geo_type, geo_value, source, signal = key
-            if alias_mapper:
-                source = alias_mapper(source, signal)
-            trends = compute_trends(geo_type, geo_value, source, signal, shifter, ((row["time_value"], row["value"]) for row in group))
-            for t in trends:
-                yield t.asdict()
+    q.set_fields(fields_string, fields_int, fields_float)
+    q.set_order("geo_type", "geo_value", "source", "signal", "time_value")
+
+    q.where_source_signal_pairs("source", "signal", source_signal_pairs)
+    q.where_geo_pairs("geo_type", "geo_value", geo_pairs)
+    q.where_time_pairs("time_type", "time_value", [TimePair("day" if is_day else "week", [time_window])])
+
+    # fetch most recent issue fast
+    _handle_lag_issues_as_of(q, None, None, None)
+
+    p = create_printer()
 
     # execute first query
     try:
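
In handle_trendseries(), the shifter lambda fixes the day-or-week shift up front so compute_trends receives a single callable. A sketch under the assumption that the real shift_time_value/shift_week_value do calendar-aware math; the stubs below merely subtract, which happens to agree for these mid-month, mid-year inputs:

# stubs: the real helpers handle month and year boundaries; plain
# subtraction is only valid for the example inputs used here
def shift_time_value(date_int: int, days: int) -> int:
    return date_int + days

def shift_week_value(week_int: int, weeks: int) -> int:
    return week_int + weeks

def make_shifter(is_day: bool, basis_shift: int = 7):
    # choose the shift function once; callers never re-check is_day
    if is_day:
        return lambda x: shift_time_value(x, -basis_shift)
    return lambda x: shift_week_value(x, -basis_shift)

print(make_shifter(is_day=True)(20200415))   # 20200408: seven days earlier
print(make_shifter(is_day=False)(202015))    # 202008: seven weeks earlier
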
@@ -483,9 +489,12 @@ def handle_export():
     start_day, is_day = parse_day_or_week_arg("start_day", 202001 if weekly_signals > 0 else 20200401)
     end_day, is_end_day = parse_day_or_week_arg("end_day", 202020 if weekly_signals > 0 else 20200901)
     time_window = (start_day, end_day)
+
     if is_day != is_end_day:
         raise ValidationFailedException("mixing weeks with day arguments")
+
     _verify_argument_time_type_matches(is_day, daily_signals, weekly_signals)
+
     transform_args = parse_transform_args()
     jit_bypass = parse_jit_bypass()
 
@@ -499,19 +508,30 @@ def handle_export():
     if is_day != is_as_of_day:
         raise ValidationFailedException("mixing weeks with day arguments")
 
+    # build query
+    q = QueryBuilder(latest_table, "t")
+
+    fields_string = ["geo_value", "signal", "geo_type", "source"]
+    fields_int = ["time_value", "issue", "lag"]
+    fields_float = ["value", "stderr", "sample_size"]
+
     use_server_side_compute = all([is_day, is_end_day]) and JIT_COMPUTE and not jit_bypass
     if use_server_side_compute:
         pad_length = get_pad_length(source_signal_pairs, transform_args.get("smoother_window_length"))
         source_signal_pairs, row_transform_generator = get_basename_signal_and_jit_generator(source_signal_pairs)
         time_window = pad_time_window(time_window, pad_length)
 
-    # build query
-    q = QueryBuilder(latest_table, "t")
+        def gen_transform(rows):
+            parsed_rows = (parse_row(row, fields_string, fields_int, fields_float) for row in rows)
+            transformed_rows = row_transform_generator(parsed_rows=parsed_rows, time_pairs=[TimePair("day", [time_window])], transform_args=transform_args)
+            for row in transformed_rows:
+                yield row
+    else:
+        def gen_transform(rows):
+            for row in rows:
+                yield row
 
-    fields_string = ["geo_value", "signal", "geo_type", "source"]
-    fields_int = ["time_value", "issue", "lag"]
-    fields_float = ["value", "stderr", "sample_size"]
-    q.set_fields(fields_string + fields_int + fields_float, [], [])
+    q.set_fields(fields_string, fields_int, fields_float)
     q.set_order("time_value", "geo_value")
     q.where_source_signal_pairs("source", "signal", source_signal_pairs)
     q.where_time_pairs("time_type", "time_value", [TimePair("day" if is_day else "week", [time_window])])
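
The q.set_fields change in handle_export() is more than cosmetic: the JIT branch's gen_transform calls parse_row(row, fields_string, fields_int, fields_float), so the three lists must stay separate for the numeric columns to be coerced; the old call lumped every field into the string list. A hypothetical parse_row_stub illustrating the idea (the real parse_row internals may differ):

def parse_row_stub(row, fields_string, fields_int, fields_float):
    # keep strings as-is; coerce numeric columns so downstream transforms
    # (smoothing, diffing) receive real numbers, not DB driver strings
    parsed = {k: row[k] for k in fields_string if k in row}
    parsed.update({k: int(row[k]) for k in fields_int if row.get(k) is not None})
    parsed.update({k: float(row[k]) for k in fields_float if row.get(k) is not None})
    return parsed

raw = {"geo_value": "pa", "time_value": "20200401", "value": "1.5", "stderr": None}
print(parse_row_stub(raw, ["geo_value"], ["time_value"], ["value", "stderr"]))
# {'geo_value': 'pa', 'time_value': 20200401, 'value': 1.5}
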
@@ -541,17 +561,6 @@ def parse_csv_row(i, row):
             "data_source": alias_mapper(row["source"], row["signal"]) if alias_mapper else row["source"],
         }
 
-    if use_server_side_compute:
-        def gen_transform(rows):
-            parsed_rows = (parse_row(row, fields_string, fields_int, fields_float) for row in rows)
-            transformed_rows = row_transform_generator(parsed_rows=parsed_rows, time_pairs=[TimePair("day", [time_window])], transform_args=transform_args)
-            for row in transformed_rows:
-                yield row
-    else:
-        def gen_transform(rows):
-            for row in rows:
-                yield row
-
     def gen_parse(rows):
         for i, row in enumerate(rows):
             yield parse_csv_row(i, row)
