|
30 | 30 | ) |
31 | 31 | from .._query import QueryBuilder, execute_query, run_query, parse_row, filter_fields |
32 | 32 | from .._printer import create_printer, CSVPrinter |
| 33 | +from .._security import current_user, sources_protected_by_roles |
33 | 34 | from .._validate import require_all |
34 | 35 | from .._pandas import as_pandas, print_pandas |
35 | 36 | from .covidcast_utils import compute_trend, compute_trends, compute_trend_value, CovidcastMetaEntry |
36 | 37 | from ..utils import shift_day_value, day_to_time_value, time_value_to_iso, time_value_to_day, shift_week_value, time_value_to_week, guess_time_value_is_day, week_to_time_value, TimeValues |
37 | 38 | from .covidcast_utils.model import TimeType, count_signal_time_types, data_sources, create_source_signal_alias_mapper |
| 39 | +from delphi.epidata.common.logger import get_structured_logger |
38 | 40 |
|
39 | 41 | # first argument is the endpoint name |
40 | 42 | bp = Blueprint("covidcast", __name__) |
|
43 | 45 | latest_table = "epimetric_latest_v" |
44 | 46 | history_table = "epimetric_full_v" |
45 | 47 |
|
| 48 | +def restrict_by_roles(source_signal_sets): |
| 49 | + # takes a list of SourceSignalSet objects |
| 50 | + # and returns only those from the list |
| 51 | + # that the current user is permitted to access. |
| 52 | + user = current_user |
| 53 | + allowed_source_signal_sets = [] |
| 54 | + for src_sig_set in source_signal_sets: |
| 55 | + src = src_sig_set.source |
| 56 | + if src in sources_protected_by_roles: |
| 57 | + role = sources_protected_by_roles[src] |
| 58 | + if user and user.has_role(role): |
| 59 | + allowed_source_signal_sets.append(src_sig_set) |
| 60 | + else: |
| 61 | + # protected src and user does not have permission => leave it out of the srcsig sets |
| 62 | + get_structured_logger("covcast_endpt").warning("user requested restricted 'source'", api_key=(user and user.api_key), src=src) |
| 63 | + else: |
| 64 | + allowed_source_signal_sets.append(src_sig_set) |
| 65 | + return allowed_source_signal_sets |
| 66 | + |
| 67 | + |
46 | 68 | @bp.route("/", methods=("GET", "POST")) |
47 | 69 | def handle(): |
48 | 70 | source_signal_sets = parse_source_signal_sets() |
| 71 | + source_signal_sets = restrict_by_roles(source_signal_sets) |
49 | 72 | source_signal_sets, alias_mapper = create_source_signal_alias_mapper(source_signal_sets) |
50 | 73 | time_set = parse_time_set() |
51 | 74 | geo_sets = parse_geo_sets() |
@@ -102,6 +125,7 @@ def _verify_argument_time_type_matches(is_day_argument: bool, count_daily_signal |
102 | 125 | def handle_trend(): |
103 | 126 | require_all(request, "window", "date") |
104 | 127 | source_signal_sets = parse_source_signal_sets() |
| 128 | + source_signal_sets = restrict_by_roles(source_signal_sets) |
105 | 129 | daily_signals, weekly_signals = count_signal_time_types(source_signal_sets) |
106 | 130 | source_signal_sets, alias_mapper = create_source_signal_alias_mapper(source_signal_sets) |
107 | 131 | geo_sets = parse_geo_sets() |
@@ -157,6 +181,7 @@ def gen(rows): |
157 | 181 | def handle_trendseries(): |
158 | 182 | require_all(request, "window") |
159 | 183 | source_signal_sets = parse_source_signal_sets() |
| 184 | + source_signal_sets = restrict_by_roles(source_signal_sets) |
160 | 185 | daily_signals, weekly_signals = count_signal_time_types(source_signal_sets) |
161 | 186 | source_signal_sets, alias_mapper = create_source_signal_alias_mapper(source_signal_sets) |
162 | 187 | geo_sets = parse_geo_sets() |
@@ -405,8 +430,19 @@ def handle_meta(): |
405 | 430 | entry = by_signal.setdefault((row["data_source"], row["signal"]), []) |
406 | 431 | entry.append(row) |
407 | 432 |
|
| 433 | + user = current_user |
408 | 434 | sources: List[Dict[str, Any]] = [] |
409 | 435 | for source in data_sources: |
| 436 | + src = source.db_source # TODO: might wanna check source.source in addition to .db_source |
| 437 | + if src in sources_protected_by_roles: |
| 438 | + role = sources_protected_by_roles[src] |
| 439 | + if not (user and user.has_role(role)): |
| 440 | + # if this is a protected source |
| 441 | + # and the user doesnt have the allowed role |
| 442 | + # (or if we have no user) |
| 443 | + # then skip this source |
| 444 | + continue |
| 445 | + |
410 | 446 | meta_signals: List[Dict[str, Any]] = [] |
411 | 447 |
|
412 | 448 | for signal in source.signals: |
@@ -448,6 +484,7 @@ def handle_coverage(): |
448 | 484 | """ |
449 | 485 |
|
450 | 486 | source_signal_sets = parse_source_signal_sets() |
| 487 | + source_signal_sets = restrict_by_roles(source_signal_sets) |
451 | 488 | daily_signals, weekly_signals = count_signal_time_types(source_signal_sets) |
452 | 489 | source_signal_sets, alias_mapper = create_source_signal_alias_mapper(source_signal_sets) |
453 | 490 |
|
|
0 commit comments