diff --git a/.dockerignore b/.dockerignore index 058bb7939..820d11c43 100644 --- a/.dockerignore +++ b/.dockerignore @@ -1,6 +1,7 @@ /delphi-epidata -/.mypy_cache +**/.mypy_cache /.github /docs -__pycache__ -/node_modules \ No newline at end of file +**/__pycache__ +**/.pytest_cache +**/node_modules \ No newline at end of file diff --git a/.env.example b/.env.example index 212706e08..501a8f216 100644 --- a/.env.example +++ b/.env.example @@ -1,18 +1,8 @@ FLASK_DEBUG=True SQLALCHEMY_DATABASE_URI=sqlite:///test.db FLASK_SECRET=abc -SECRET_TWITTER=abc -SECRET_GHT=abc -SECRET_FLUVIEW=abc -SECRET_CDC=abc -SECRET_SENSORS=abc -SECRET_SENSOR_TWTR=abc -SECRET_SENSOR_GFT=abc -SECRET_SENSOR_GHT=abc -SECRET_SENSOR_GHTJ=abc -SECRET_SENSOR_CDC=abc -SECRET_SENSOR_QUID=abc -SECRET_SENSOR_WIKI=abc -SECRET_QUIDEL=abc -SECRET_NOROSTAT=abc -SECRET_AFHSB=abc \ No newline at end of file +#API_REQUIRED_STARTING_AT=2021-07-30 +API_KEY_ADMIN_PASSWORD=abc +API_KEY_REGISTER_WEBHOOK_TOKEN=abc +RECAPTCHA_SITE_KEY= +RECAPTCHA_SECRET_KEY= \ No newline at end of file diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 7a155e6f9..09df27542 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -53,6 +53,7 @@ jobs: run: | docker build -t delphi_database_epidata -f ./repos/delphi/delphi-epidata/dev/docker/database/epidata/Dockerfile . docker build -t delphi_web_python -f repos/delphi/delphi-epidata/dev/docker/python/Dockerfile . + docker build -t delphi_redis_instance -f repos/delphi/delphi-epidata/dev/docker/redis/Dockerfile . cd ./repos/delphi/delphi-epidata docker build -t delphi_web_epidata -f ./devops/Dockerfile . 
cd ../../../ @@ -63,7 +64,8 @@ jobs: run: | docker network create --driver bridge delphi-net docker run --rm -d -p 13306:3306 --network delphi-net --name delphi_database_epidata --cap-add=sys_nice delphi_database_epidata - docker run --rm -d -p 10080:80 --env "MODULE_NAME=delphi.epidata.server.main" --env "SQLALCHEMY_DATABASE_URI=mysql+mysqldb://user:pass@delphi_database_epidata:3306/epidata" --env "FLASK_SECRET=abc" --env "FLASK_PREFIX=/epidata" --network delphi-net --name delphi_web_epidata delphi_web_epidata + docker run --rm -d -p 10080:80 --env "MODULE_NAME=delphi.epidata.server.main" --env "SQLALCHEMY_DATABASE_URI=mysql+mysqldb://user:pass@delphi_database_epidata:3306/epidata" --env "FLASK_SECRET=abc" --env "FLASK_PREFIX=/epidata" --env "RATELIMIT_STORAGE_URL=redis://delphi_redis_instance:6379" --env "API_KEY_REGISTER_WEBHOOK_TOKEN=abc" --env "API_KEY_ADMIN_PASSWORD=test_admin_password" --network delphi-net --name delphi_web_epidata delphi_web_epidata + docker run --rm -d -p 6379:6379 --network delphi-net --name delphi_redis_instance delphi_redis_instance docker ps - run: | diff --git a/dev/docker/python/Dockerfile b/dev/docker/python/Dockerfile index ffce16b0f..9d6b262e9 100644 --- a/dev/docker/python/Dockerfile +++ b/dev/docker/python/Dockerfile @@ -5,6 +5,7 @@ WORKDIR /usr/src/app COPY repos repos COPY repos/delphi/delphi-epidata/dev/docker/python/setup.sh . 
+ RUN ln -s -f /usr/share/zoneinfo/America/New_York /etc/localtime && \ chmod -R o+r repos/ && \ bash setup.sh && \ diff --git a/dev/docker/redis/Dockerfile b/dev/docker/redis/Dockerfile new file mode 100644 index 000000000..ae972e65f --- /dev/null +++ b/dev/docker/redis/Dockerfile @@ -0,0 +1,3 @@ +FROM redis + +CMD ["redis-server"] \ No newline at end of file diff --git a/devops/Dockerfile b/devops/Dockerfile index 97dc0e2c8..e5f20e9aa 100644 --- a/devops/Dockerfile +++ b/devops/Dockerfile @@ -23,7 +23,6 @@ RUN ln -s -f /usr/share/zoneinfo/America/New_York /etc/localtime \ # the file /tmp/requirements.txt is created in the parent docker definition. (see: # https://github.com/tiangolo/meinheld-gunicorn-docker/blob/master/docker-images/python3.8.dockerfile#L5 ) # this combined requirements installation ensures all version constrants are accounted for. - # disable python stdout buffering ENV PYTHONUNBUFFERED 1 diff --git a/integrations/acquisition/covid_hosp/facility/test_scenarios.py b/integrations/acquisition/covid_hosp/facility/test_scenarios.py index aaa3c5e3b..775fe6b8b 100644 --- a/integrations/acquisition/covid_hosp/facility/test_scenarios.py +++ b/integrations/acquisition/covid_hosp/facility/test_scenarios.py @@ -29,6 +29,7 @@ def setUp(self): # use the local instance of the Epidata API Epidata.BASE_URL = 'http://delphi_web_epidata/epidata/api.php' + Epidata.auth = ('epidata', 'key') # use the local instance of the epidata database secrets.db.host = 'delphi_database_epidata' @@ -40,6 +41,8 @@ def setUp(self): cur.execute('truncate table covid_hosp_facility') cur.execute('truncate table covid_hosp_facility_key') cur.execute('truncate table covid_hosp_meta') + cur.execute('delete from api_user') + cur.execute('insert into api_user(api_key, tracking, registered) values ("key", 1, 1)') @freeze_time("2021-03-16") def test_acquire_dataset(self): diff --git a/integrations/acquisition/covid_hosp/state_daily/test_scenarios.py 
b/integrations/acquisition/covid_hosp/state_daily/test_scenarios.py index e55bc8ca6..424b3019a 100644 --- a/integrations/acquisition/covid_hosp/state_daily/test_scenarios.py +++ b/integrations/acquisition/covid_hosp/state_daily/test_scenarios.py @@ -33,6 +33,7 @@ def setUp(self): # use the local instance of the Epidata API Epidata.BASE_URL = 'http://delphi_web_epidata/epidata/api.php' + Epidata.auth = ('epidata', 'key') # use the local instance of the epidata database secrets.db.host = 'delphi_database_epidata' @@ -43,6 +44,8 @@ def setUp(self): with db.new_cursor() as cur: cur.execute('truncate table covid_hosp_state_timeseries') cur.execute('truncate table covid_hosp_meta') + cur.execute('delete from api_user') + cur.execute('insert into api_user(api_key, tracking, registered) values("key", 1, 1)') @freeze_time("2021-03-16") def test_acquire_dataset(self): diff --git a/integrations/acquisition/covid_hosp/state_timeseries/test_scenarios.py b/integrations/acquisition/covid_hosp/state_timeseries/test_scenarios.py index 5d13ccbb0..1384927e8 100644 --- a/integrations/acquisition/covid_hosp/state_timeseries/test_scenarios.py +++ b/integrations/acquisition/covid_hosp/state_timeseries/test_scenarios.py @@ -29,6 +29,7 @@ def setUp(self): # use the local instance of the Epidata API Epidata.BASE_URL = 'http://delphi_web_epidata/epidata/api.php' + Epidata.auth = ('epidata', 'key') # use the local instance of the epidata database secrets.db.host = 'delphi_database_epidata' @@ -39,6 +40,8 @@ def setUp(self): with db.new_cursor() as cur: cur.execute('truncate table covid_hosp_state_timeseries') cur.execute('truncate table covid_hosp_meta') + cur.execute('delete from api_user') + cur.execute('insert into api_user(api_key, tracking, registered) values ("key", 1, 1)') @freeze_time("2021-03-17") def test_acquire_dataset(self): diff --git a/integrations/acquisition/covidcast/test_covidcast_meta_caching.py b/integrations/acquisition/covidcast/test_covidcast_meta_caching.py index 
99008a0f1..e746c4ef1 100644 --- a/integrations/acquisition/covidcast/test_covidcast_meta_caching.py +++ b/integrations/acquisition/covidcast/test_covidcast_meta_caching.py @@ -60,12 +60,20 @@ def setUp(self): # use the local instance of the Epidata API Epidata.BASE_URL = BASE_URL + Epidata.auth = ('epidata', 'key') def tearDown(self): """Perform per-test teardown.""" self.cur.close() self.cnx.close() + @staticmethod + def _make_request(): + params = {'endpoint': 'covidcast_meta', 'cached': 'true'} + response = requests.get(Epidata.BASE_URL, params=params, auth=Epidata.auth) + response.raise_for_status() + return response.json() + def test_caching(self): """Populate, query, cache, query, and verify the cache.""" @@ -147,10 +155,7 @@ def test_caching(self): self.cnx.commit() # fetch the cached version (manually) - params = {'endpoint': 'covidcast_meta', 'cached': 'true'} - response = requests.get(BASE_URL, params=params) - response.raise_for_status() - epidata4 = response.json() + epidata4 = self._make_request() # make sure the cache was actually served self.assertEqual(epidata4, { @@ -170,10 +175,7 @@ def test_caching(self): self.cnx.commit() # fetch the cached version (manually) - params = {'endpoint': 'covidcast_meta', 'cached': 'true'} - response = requests.get(BASE_URL, params=params) - response.raise_for_status() - epidata5 = response.json() + epidata5 = self._make_request() # make sure the cache was returned anyhow self.assertEqual(epidata4, epidata5) diff --git a/integrations/acquisition/covidcast/test_csv_uploading.py b/integrations/acquisition/covidcast/test_csv_uploading.py index f975ecfa0..e2db98cf3 100644 --- a/integrations/acquisition/covidcast/test_csv_uploading.py +++ b/integrations/acquisition/covidcast/test_csv_uploading.py @@ -57,6 +57,7 @@ def setUp(self): # use the local instance of the Epidata API Epidata.BASE_URL = 'http://delphi_web_epidata/epidata/api.php' + Epidata.auth = ('epidata', 'key') def tearDown(self): """Perform per-test 
teardown.""" diff --git a/integrations/acquisition/covidcast_nowcast/test_csv_uploading.py b/integrations/acquisition/covidcast_nowcast/test_csv_uploading.py index 9dc163a2b..c9bbf77ce 100644 --- a/integrations/acquisition/covidcast_nowcast/test_csv_uploading.py +++ b/integrations/acquisition/covidcast_nowcast/test_csv_uploading.py @@ -41,6 +41,8 @@ def setUp(self): database='epidata') cur = cnx.cursor() cur.execute('truncate table covidcast_nowcast') + cur.execute('delete from api_user') + cur.execute('insert into api_user(api_key, tracking, registered) values ("key", 1, 1)') cnx.commit() cur.close() @@ -54,6 +56,7 @@ def setUp(self): # use the local instance of the Epidata API Epidata.BASE_URL = 'http://delphi_web_epidata/epidata/api.php' + Epidata.auth = ('epidata', 'key') def tearDown(self): """Perform per-test teardown.""" diff --git a/integrations/client/test_delphi_epidata.py b/integrations/client/test_delphi_epidata.py index 82c1452ec..8072b88ee 100644 --- a/integrations/client/test_delphi_epidata.py +++ b/integrations/client/test_delphi_epidata.py @@ -41,6 +41,7 @@ def localSetUp(self): # use the local instance of the Epidata API Epidata.BASE_URL = 'http://delphi_web_epidata/epidata/api.php' + Epidata.auth = ('epidata', 'key') # use the local instance of the epidata database secrets.db.host = 'delphi_database_epidata' diff --git a/integrations/client/test_nowcast.py b/integrations/client/test_nowcast.py index dc1a20794..6860542cb 100644 --- a/integrations/client/test_nowcast.py +++ b/integrations/client/test_nowcast.py @@ -28,6 +28,8 @@ def setUp(self): cur = cnx.cursor() cur.execute('truncate table covidcast_nowcast') + cur.execute('delete from api_user') + cur.execute('insert into api_user(api_key, tracking, registered) values ("key", 1, 1)') cnx.commit() cur.close() @@ -38,6 +40,7 @@ def setUp(self): # use the local instance of the Epidata API Epidata.BASE_URL = 'http://delphi_web_epidata/epidata/api.php' + Epidata.auth = ('epidata', 'key') # use the 
local instance of the epidata database secrets.db.host = 'delphi_database_epidata' diff --git a/integrations/server/test_covid_hosp.py b/integrations/server/test_covid_hosp.py index 16538b82d..ba27698e8 100644 --- a/integrations/server/test_covid_hosp.py +++ b/integrations/server/test_covid_hosp.py @@ -17,6 +17,7 @@ def setUp(self): # use the local instance of the Epidata API Epidata.BASE_URL = 'http://delphi_web_epidata/epidata/api.php' + Epidata.auth = ('epidata', 'key') # use the local instance of the epidata database secrets.db.host = 'delphi_database_epidata' @@ -27,6 +28,8 @@ def setUp(self): with db.new_cursor() as cur: cur.execute('truncate table covid_hosp_state_timeseries') cur.execute('truncate table covid_hosp_meta') + cur.execute('delete from api_user') + cur.execute('insert into api_user(api_key, tracking, registered) values ("key", 1, 1)') def insert_issue(self, cur, issue, value, record_type): diff --git a/integrations/server/test_covidcast.py b/integrations/server/test_covidcast.py index 5a8df96f0..582230b40 100644 --- a/integrations/server/test_covidcast.py +++ b/integrations/server/test_covidcast.py @@ -6,15 +6,12 @@ # third party import mysql.connector -import requests # first party from delphi_utils import Nans from delphi.epidata.acquisition.covidcast.test_utils import CovidcastBase, CovidcastTestRow from delphi.epidata.client.delphi_epidata import Epidata -# use the local instance of the Epidata API -BASE_URL = 'http://delphi_web_epidata/epidata/api.php' class CovidcastTests(CovidcastBase): """Tests the `covidcast` endpoint.""" @@ -25,7 +22,9 @@ def localSetUp(self): def request_based_on_row(self, row: CovidcastTestRow, **kwargs): params = self.params_from_row(row, endpoint='covidcast', **kwargs) - Epidata.BASE_URL = BASE_URL + # use the local instance of the Epidata API + Epidata.BASE_URL = 'http://delphi_web_epidata/epidata/api.php' + Epidata.auth = ('epidata', 'key') response = Epidata.covidcast(**params) return response diff --git 
a/integrations/server/test_covidcast_endpoints.py b/integrations/server/test_covidcast_endpoints.py index 1f7e7ade5..41b74ac03 100644 --- a/integrations/server/test_covidcast_endpoints.py +++ b/integrations/server/test_covidcast_endpoints.py @@ -15,6 +15,7 @@ # use the local instance of the Epidata API BASE_URL = "http://delphi_web_epidata/epidata/covidcast" BASE_URL_OLD = "http://delphi_web_epidata/epidata/api.php" +AUTH = ('epidata', 'key') class CovidcastEndpointTests(CovidcastBase): @@ -36,7 +37,7 @@ def _fetch(self, endpoint="/", is_compatibility=False, **params): params.setdefault("data_source", params.get("source")) else: url = f"{BASE_URL}{endpoint}" - response = requests.get(url, params=params) + response = requests.get(url, params=params, auth=AUTH) response.raise_for_status() return response.json() diff --git a/integrations/server/test_covidcast_meta.py b/integrations/server/test_covidcast_meta.py index d0aef6fe5..95a51e354 100644 --- a/integrations/server/test_covidcast_meta.py +++ b/integrations/server/test_covidcast_meta.py @@ -14,6 +14,7 @@ # use the local instance of the Epidata API BASE_URL = 'http://delphi_web_epidata/epidata/api.php' +AUTH = ('epidata', 'key') class CovidcastMetaTests(unittest.TestCase): @@ -135,6 +136,14 @@ def insert_placeholder_data(self): def _get_id(self): self.id_counter += 1 return self.id_counter + + @staticmethod + def _fetch(**kwargs): + params = kwargs.copy() + params['endpoint'] = 'covidcast_meta' + response = requests.get(BASE_URL, params=params, auth=AUTH) + response.raise_for_status() + return response.json() def test_round_trip(self): """Make a simple round-trip with some sample data.""" @@ -143,9 +152,7 @@ def test_round_trip(self): expected = self.insert_placeholder_data() # make the request - response = requests.get(BASE_URL, params={'endpoint': 'covidcast_meta'}) - response.raise_for_status() - response = response.json() + response = self._fetch() # assert that the right data came back 
self.assertEqual(response, { @@ -160,71 +167,63 @@ def test_filter(self): # insert placeholder data and accumulate expected results (in sort order) expected = self.insert_placeholder_data() - def fetch(**kwargs): - # make the request - params = kwargs.copy() - params['endpoint'] = 'covidcast_meta' - response = requests.get(BASE_URL, params=params) - response.raise_for_status() - return response.json() - - res = fetch() + res = self._fetch() self.assertEqual(res['result'], 1) self.assertEqual(len(res['epidata']), len(expected)) # time types - res = fetch(time_types='day') + res = self._fetch(time_types='day') self.assertEqual(res['result'], 1) self.assertEqual(len(res['epidata']), sum([1 for s in expected if s['time_type'] == 'day'])) - res = fetch(time_types='day,week') + res = self._fetch(time_types='day,week') self.assertEqual(res['result'], 1) self.assertTrue(isinstance(res['epidata'], list)) self.assertEqual(len(res['epidata']), len(expected)) - res = fetch(time_types='sec') + res = self._fetch(time_types='sec') self.assertEqual(res['result'], -2) # geo types - res = fetch(geo_types='hrr') + res = self._fetch(geo_types='hrr') self.assertEqual(res['result'], 1) self.assertTrue(isinstance(res['epidata'], list)) self.assertEqual(len(res['epidata']), sum([1 for s in expected if s['geo_type'] == 'hrr'])) - res = fetch(geo_types='hrr,msa') + res = self._fetch(geo_types='hrr,msa') self.assertEqual(res['result'], 1) self.assertTrue(isinstance(res['epidata'], list)) self.assertEqual(len(res['epidata']), len(expected)) - res = fetch(geo_types='state') + res = self._fetch(geo_types='state') self.assertEqual(res['result'], -2) # signals - res = fetch(signals='src1:sig1') + res = self._fetch(signals='src1:sig1') self.assertEqual(res['result'], 1) self.assertTrue(isinstance(res['epidata'], list)) self.assertEqual(len(res['epidata']), sum([1 for s in expected if s['data_source'] == 'src1' and s['signal'] == 'sig1'])) - res = fetch(signals='src1') + res = 
self._fetch(signals='src1') self.assertEqual(res['result'], 1) self.assertTrue(isinstance(res['epidata'], list)) self.assertEqual(len(res['epidata']), sum([1 for s in expected if s['data_source'] == 'src1'])) - res = fetch(signals='src1:*') + res = self._fetch(signals='src1:*') self.assertEqual(res['result'], 1) self.assertTrue(isinstance(res['epidata'], list)) self.assertEqual(len(res['epidata']), sum([1 for s in expected if s['data_source'] == 'src1'])) - res = fetch(signals='src1:src4') + res = self._fetch(signals='src1:src4') self.assertEqual(res['result'], -2) - res = fetch(signals='src1:*,src2:*') + res = self._fetch(signals='src1:*,src2:*') self.assertEqual(res['result'], 1) self.assertTrue(isinstance(res['epidata'], list)) self.assertEqual(len(res['epidata']), len(expected)) # filter fields - res = fetch(fields='data_source,min_time') + res = self._fetch(fields='data_source,min_time') self.assertEqual(res['result'], 1) self.assertEqual(len(res['epidata']), len(expected)) self.assertTrue('data_source' in res['epidata'][0]) @@ -232,7 +231,7 @@ def fetch(**kwargs): self.assertFalse('max_time' in res['epidata'][0]) self.assertFalse('signal' in res['epidata'][0]) - res = fetch(fields='xx') + res = self._fetch(fields='xx') self.assertEqual(res['result'], 1) self.assertEqual(len(res['epidata']), len(expected)) self.assertEqual(res['epidata'][0], {}) diff --git a/integrations/server/test_covidcast_nowcast.py b/integrations/server/test_covidcast_nowcast.py index 7df695038..2163010d6 100644 --- a/integrations/server/test_covidcast_nowcast.py +++ b/integrations/server/test_covidcast_nowcast.py @@ -10,6 +10,7 @@ # use the local instance of the Epidata API BASE_URL = 'http://delphi_web_epidata/epidata/api.php' +AUTH = ('epidata', 'key') class CovidcastTests(unittest.TestCase): @@ -26,6 +27,8 @@ def setUp(self): database='epidata') cur = cnx.cursor() cur.execute('truncate table covidcast_nowcast') + cur.execute('delete from api_user') + cur.execute('insert into 
api_user(api_key, tracking, registered) values("key", 1, 1)') cnx.commit() cur.close() @@ -38,6 +41,12 @@ def tearDown(self): self.cur.close() self.cnx.close() + @staticmethod + def _make_request(params: dict): + response = requests.get(BASE_URL, params=params, auth=AUTH) + response.raise_for_status() + return response.json() + def test_query(self): """Query nowcasts using default and specified issue.""" @@ -49,7 +58,7 @@ def test_query(self): self.cnx.commit() # make the request with specified issue date - response = requests.get(BASE_URL, params={ + params={ 'source': 'covidcast_nowcast', 'data_source': 'src', 'signals': 'sig', @@ -59,9 +68,8 @@ def test_query(self): 'time_values': 20200101, 'geo_value': '01001', 'issues': 20200101 - }) - response.raise_for_status() - response = response.json() + } + response = self._make_request(params=params) self.assertEqual(response, { 'result': 1, 'epidata': [{ @@ -76,7 +84,7 @@ def test_query(self): }) # make request without specific issue date - response = requests.get(BASE_URL, params={ + params={ 'source': 'covidcast_nowcast', 'data_source': 'src', 'signals': 'sig', @@ -85,9 +93,8 @@ def test_query(self): 'geo_type': 'county', 'time_values': 20200101, 'geo_value': '01001', - }) - response.raise_for_status() - response = response.json() + } + response = self._make_request(params=params) self.assertEqual(response, { 'result': 1, @@ -102,7 +109,7 @@ def test_query(self): 'message': 'success', }) - response = requests.get(BASE_URL, params={ + params={ 'source': 'covidcast_nowcast', 'data_source': 'src', 'signals': 'sig', @@ -112,9 +119,8 @@ def test_query(self): 'time_values': 20200101, 'geo_value': '01001', 'as_of': 20200101 - }) - response.raise_for_status() - response = response.json() + } + response = self._make_request(params=params) self.assertEqual(response, { 'result': 1, diff --git a/integrations/server/test_fluview.py b/integrations/server/test_fluview.py index 8bfc18376..152bb3883 100644 --- 
a/integrations/server/test_fluview.py +++ b/integrations/server/test_fluview.py @@ -19,6 +19,7 @@ def setUpClass(cls): # use the local instance of the Epidata API Epidata.BASE_URL = 'http://delphi_web_epidata/epidata/api.php' + Epidata.auth = ('epidata', 'key') def setUp(self): """Perform per-test setup.""" @@ -31,6 +32,8 @@ def setUp(self): database='epidata') cur = cnx.cursor() cur.execute('truncate table fluview') + cur.execute('delete from api_user') + cur.execute('insert into api_user(api_key, tracking, registered) values ("key", 1, 1)') cnx.commit() cur.close() diff --git a/integrations/server/test_fluview_meta.py b/integrations/server/test_fluview_meta.py index 137e9464a..9c1cb9b6d 100644 --- a/integrations/server/test_fluview_meta.py +++ b/integrations/server/test_fluview_meta.py @@ -19,6 +19,7 @@ def setUpClass(cls): # use the local instance of the Epidata API Epidata.BASE_URL = 'http://delphi_web_epidata/epidata/api.php' + Epidata.auth = ('epidata', 'key') def setUp(self): """Perform per-test setup.""" @@ -31,6 +32,8 @@ def setUp(self): database='epidata') cur = cnx.cursor() cur.execute('truncate table fluview') + cur.execute('delete from api_user') + cur.execute('insert into api_user(api_key, tracking, registered) values ("key", 1, 1)') cnx.commit() cur.close() diff --git a/requirements.api.txt b/requirements.api.txt index 6ccafc1e1..d35b66a95 100644 --- a/requirements.api.txt +++ b/requirements.api.txt @@ -14,3 +14,6 @@ structlog==22.1.0 tenacity==7.0.0 typing-extensions werkzeug==2.2.2 +Flask-Limiter==1.4 +redis==3.5.3 +requests==2.28.1 diff --git a/src/client/delphi_epidata.py b/src/client/delphi_epidata.py index 9b3deea94..8f3f99d39 100644 --- a/src/client/delphi_epidata.py +++ b/src/client/delphi_epidata.py @@ -13,7 +13,7 @@ import asyncio from tenacity import retry, stop_after_attempt -from aiohttp import ClientSession, TCPConnector +from aiohttp import ClientSession, TCPConnector, BasicAuth from pkg_resources import get_distribution, 
DistributionNotFound # Obtain package version for the user-agent. Uses the installed version by @@ -34,6 +34,7 @@ class Epidata: # API base url BASE_URL = 'https://delphi.cmu.edu/epidata/api.php' + auth = None client_version = _version @@ -58,9 +59,9 @@ def _list(values): @retry(reraise=True, stop=stop_after_attempt(2)) def _request_with_retry(params): """Make request with a retry if an exception is thrown.""" - req = requests.get(Epidata.BASE_URL, params, headers=_HEADERS) + req = requests.get(Epidata.BASE_URL, params, auth=Epidata.auth, headers=_HEADERS) if req.status_code == 414: - req = requests.post(Epidata.BASE_URL, params, headers=_HEADERS) + req = requests.post(Epidata.BASE_URL, params, auth=Epidata.auth, headers=_HEADERS) return req @staticmethod @@ -736,7 +737,8 @@ async def async_make_calls(param_combos): """Helper function to asynchronously make and aggregate Epidata GET requests.""" tasks = [] connector = TCPConnector(limit=batch_size) - async with ClientSession(connector=connector, headers=_HEADERS) as session: + auth = BasicAuth(login=Epidata.auth[0], password=Epidata.auth[1], encoding='utf-8') if Epidata.auth else None + async with ClientSession(connector=connector, headers=_HEADERS, auth=auth) as session: for param in param_combos: task = asyncio.ensure_future(async_get(param, session)) tasks.append(task) diff --git a/src/ddl/api_user.sql b/src/ddl/api_user.sql new file mode 100644 index 000000000..11fa6feba --- /dev/null +++ b/src/ddl/api_user.sql @@ -0,0 +1,26 @@ +USE epidata; +/* +`api_user` API key and user management +This data is private to Delphi. 
++----------------------+---------------+------+-----+-------------------+----------------+ +| Field | Type | Null | Key | Default | Extra | ++----------------------+---------------+------+-----+-------------------+----------------+ +| id | int(11) | NO | PRI | | auto_increment | +| api_key | varchar(50) | NO | | | unique | +| tracking | tinyint(1) | YES | | | | +| registered | tinyint(1) | YES | | | | +| creation_date | datetime | NO | | current_timestamp | | +| last_api_access_date | datetime | NO | | current_timestamp | | ++------------+---------------+------+-----+---------+------------------------------------+ +*/ + +CREATE TABLE IF NOT EXISTS `api_user` ( + `id` int(11) NOT NULL PRIMARY KEY AUTO_INCREMENT, + `api_key` varchar(50) NOT NULL , + `email` varchar(50) UNIQUE NULL, + `tracking` tinyint(1) NULL, + `registered` tinyint(1) NULL, + `created` date, + `last_time_used` date, + UNIQUE KEY `api_key` (`api_key`, `email`) +) ENGINE=InnoDB DEFAULT CHARSET=utf8; \ No newline at end of file diff --git a/src/ddl/user_role.sql b/src/ddl/user_role.sql new file mode 100644 index 000000000..21d344bc5 --- /dev/null +++ b/src/ddl/user_role.sql @@ -0,0 +1,17 @@ +USE epidata; +/* +`user_roles` User roles +This data is private to Delphi. 
++------------+---------------+------+-----+---------+----------------+ +| Field | Type | Null | Key | Default | Extra | ++------------+---------------+------+-----+---------+----------------+ +| id | int(11) | NO | PRI | | auto_increment | +| name | varchar(50) | NO | | | unique | ++------------+---------------+------+-----+---------+----------------+ +*/ + +CREATE TABLE IF NOT EXISTS `user_role` ( + `id` int(11) NOT NULL PRIMARY KEY AUTO_INCREMENT, + `name` varchar(50) NOT NULL, + UNIQUE KEY `name` (`name`) +) ENGINE=InnoDB DEFAULT CHARSET=utf8; \ No newline at end of file diff --git a/src/ddl/user_role_link.sql b/src/ddl/user_role_link.sql new file mode 100644 index 000000000..6da0a3ae4 --- /dev/null +++ b/src/ddl/user_role_link.sql @@ -0,0 +1,17 @@ +USE epidata; +/* +`user_role_link` Maps users to their roles +This data is private to Delphi. ++------------+---------------+------+-----+---------+----------------+ +| Field | Type | Null | Key | Default | Extra | ++------------+---------------+------+-----+---------+----------------+ +| user_id | int(11) | NO | PRI | | | +| role_id | int(11) | NO | PRI | | | ++------------+---------------+------+-----+---------+----------------+ +*/ + +CREATE TABLE IF NOT EXISTS `user_role_link` ( + `user_id` int(11) NOT NULL, + `role_id` int(11) NOT NULL, + PRIMARY KEY (`user_id`, `role_id`) +) ENGINE=InnoDB DEFAULT CHARSET=utf8; \ No newline at end of file diff --git a/src/server/_common.py b/src/server/_common.py index 2d2d3059f..ed59669e3 100644 --- a/src/server/_common.py +++ b/src/server/_common.py @@ -13,6 +13,9 @@ engine: Engine = create_engine(SQLALCHEMY_DATABASE_URI, **SQLALCHEMY_ENGINE_OPTIONS) app = Flask("EpiData", static_url_path="") +# for example if the request goes through one proxy +# before hitting your application server +# app.wsgi_app = ProxyFix(app.wsgi_app, num_proxies=1) app.config["SECRET"] = SECRET diff --git a/src/server/_config.py b/src/server/_config.py index 618407f75..ce5d40b77 100644 --- a/src/server/_config.py +++ 
b/src/server/_config.py @@ -1,6 +1,9 @@ import json import os from dotenv import load_dotenv +import json +from enum import Enum +from datetime import date load_dotenv() @@ -13,8 +16,8 @@ # defaults SQLALCHEMY_ENGINE_OPTIONS = { - "pool_pre_ping": True, # enable ping test for validity of recycled pool connections on connect() calls - "pool_recycle": 5 # seconds after which a recycled pool connection is considered invalid + "pool_pre_ping": True, # enable ping test for validity of recycled pool connections on connect() calls + "pool_recycle": 5 # seconds after which a recycled pool connection is considered invalid } # update with overrides of defaults or additions from external configs SQLALCHEMY_ENGINE_OPTIONS.update( @@ -23,36 +26,6 @@ SECRET = os.environ.get("FLASK_SECRET", "secret") URL_PREFIX = os.environ.get("FLASK_PREFIX", "/") -AUTH = { - "twitter": os.environ.get("SECRET_TWITTER"), - "ght": os.environ.get("SECRET_GHT"), - "fluview": os.environ.get("SECRET_FLUVIEW"), - "cdc": os.environ.get("SECRET_CDC"), - "sensors": os.environ.get("SECRET_SENSORS"), - "quidel": os.environ.get("SECRET_QUIDEL"), - "norostat": os.environ.get("SECRET_NOROSTAT"), - "afhsb": os.environ.get("SECRET_AFHSB"), -} - -# begin sensor query authentication configuration -# A multimap of sensor names to the "granular" auth tokens that can be used to access them; excludes the "global" sensor auth key that works for all sensors: -GRANULAR_SENSOR_AUTH_TOKENS = { - "twtr": os.environ.get("SECRET_SENSOR_TWTR", "").split(","), - "gft": os.environ.get("SECRET_SENSOR_GFT", "").split(","), - "ght": os.environ.get("SECRET_SENSOR_GHT", "").split(","), - "ghtj": os.environ.get("SECRET_SENSOR_GHTJ", "").split(","), - "cdc": os.environ.get("SECRET_SENSOR_CDC", "").split(","), - "quid": os.environ.get("SECRET_SENSOR_QUID", "").split(","), - "wiki": os.environ.get("SECRET_SENSOR_WIKI", "").split(","), -} - -# A set of sensors that do not require an auth key to access: -OPEN_SENSORS = [ - "sar3", - 
"epic", - "arch", -] - REGION_TO_STATE = { "hhs1": ["VT", "CT", "ME", "MA", "NH", "RI"], "hhs2": ["NJ", "NY"], @@ -75,3 +48,60 @@ "cen9": ["AK", "CA", "HI", "OR", "WA"], } NATION_REGION = "nat" + +API_KEY_REQUIRED_STARTING_AT = date.fromisoformat(os.environ.get('API_REQUIRED_STARTING_AT', '3000-01-01')) +# password needed for the admin application if not set the admin routes won't be available +ADMIN_PASSWORD = os.environ.get('API_KEY_ADMIN_PASSWORD') +# secret for the google form to give to the admin/register endpoint +REGISTER_WEBHOOK_TOKEN = os.environ.get('API_KEY_REGISTER_WEBHOOK_TOKEN') +# see recaptcha +RECAPTCHA_SITE_KEY = os.environ.get('RECAPTCHA_SITE_KEY') +RECAPTCHA_SECRET_KEY = os.environ.get('RECAPTCHA_SECRET_KEY') + +# https://flask-limiter.readthedocs.io/en/stable/#rate-limit-string-notation +RATE_LIMIT = os.environ.get('RATE_LIMIT', '10/hour') +# fixed-window, fixed-window-elastic-expiry, or moving-window +# see also https://flask-limiter.readthedocs.io/en/stable/#rate-limiting-strategies +RATELIMIT_STRATEGY = os.environ.get('RATELIMIT_STRATEGY', 'fixed-window') +# see https://flask-limiter.readthedocs.io/en/stable/#configuration +RATELIMIT_STORAGE_URL = os.environ.get('RATELIMIT_STORAGE_URL', 'memory://') + +REDIS_HOST = os.environ.get('REDIS_HOST', "delphi_redis_instance") + + +class UserRole(str, Enum): + afhsb = "afhsb" + cdc = "cdc" + fluview = "fluview" + ght = "ght" + norostat = "norostat" + quidel = "quidel" + sensors = "sensors" + sensor_twtr = "sensor_twtr" + sensor_gft = "sensor_gft" + sensor_ght = "sensor_ght" + sensor_ghtj = "sensor_ghtj" + sensor_cdc = "sensor_cdc" + sensor_quid = "sensor_quid" + sensor_wiki = "sensor_wiki" + twitter = "twitter" + +# Begin sensor query authentication configuration +# A multimap of sensor names to the "granular" auth tokens that can be used to access them; +# excludes the "global" sensor auth key that works for all sensors: +GRANULAR_SENSOR_ROLES = { + "twtr": UserRole.sensor_twtr, + "gft": 
UserRole.sensor_gft, + "ght": UserRole.sensor_ght, + "ghtj": UserRole.sensor_ghtj, + "cdc": UserRole.sensor_cdc, + "quid": UserRole.sensor_quid, + "wiki": UserRole.sensor_wiki, +} + +# A set of sensors that do not require an auth key to access: +OPEN_SENSORS = [ + "sar3", + "epic", + "arch", +] \ No newline at end of file diff --git a/src/server/_db.py b/src/server/_db.py new file mode 100644 index 000000000..2398e2969 --- /dev/null +++ b/src/server/_db.py @@ -0,0 +1,33 @@ +from typing import Dict, List +from sqlalchemy import MetaData, create_engine, inspect +from sqlalchemy.engine import Engine +from sqlalchemy.engine.reflection import Inspector +from sqlalchemy.ext.declarative import declarative_base +from sqlalchemy.orm import sessionmaker + +from ._config import SQLALCHEMY_DATABASE_URI, SQLALCHEMY_ENGINE_OPTIONS + +engine: Engine = create_engine(SQLALCHEMY_DATABASE_URI, **SQLALCHEMY_ENGINE_OPTIONS) +metadata = MetaData(bind=engine) + +Base = declarative_base() +Session = sessionmaker(bind=engine) +session = Session() + + +TABLE_OPTIONS = dict( + mysql_engine="InnoDB", + # mariadb_engine="InnoDB", + mysql_charset="utf8mb4", + # mariadb_charset="utf8", +) + + +def sql_table_has_columns(table: str, columns: List[str]) -> bool: + """ + checks whether the given table has all the given columns defined + """ + inspector: Inspector = inspect(engine) + table_columns: List[Dict] = inspector.get_columns(table) + table_column_names = set(str(d.get("name", "")).lower() for d in table_columns) + return all(c.lower() in table_column_names for c in columns) diff --git a/src/server/_exceptions.py b/src/server/_exceptions.py index 835bfc118..942be1290 100644 --- a/src/server/_exceptions.py +++ b/src/server/_exceptions.py @@ -30,6 +30,11 @@ def __init__(self): super(UnAuthenticatedException, self).__init__("unauthenticated", 401) +class MissingAPIKeyException(EpiDataException): + def __init__(self): + super(MissingAPIKeyException, self).__init__("missing api key", 401) + + class 
ValidationFailedException(EpiDataException): def __init__(self, message: str): super(ValidationFailedException, self).__init__(message, 400) diff --git a/src/server/_printer.py b/src/server/_printer.py index 162ba2e36..8bd80a298 100644 --- a/src/server/_printer.py +++ b/src/server/_printer.py @@ -7,6 +7,7 @@ import orjson from ._config import MAX_RESULTS, MAX_COMPATIBILITY_RESULTS +from ._security import show_hard_api_key_warning, show_soft_api_key_warning, API_KEY_WARNING_TEXT from ._common import is_compatibility_mode from delphi.epidata.common.logger import get_structured_logger @@ -22,7 +23,7 @@ def print_non_standard(format: str, data): message = "no results" result = -2 else: - message = "success" + message = API_KEY_WARNING_TEXT if show_soft_api_key_warning() else "success" result = 1 if result == -1 and is_compatibility_mode(): return jsonify(dict(result=result, message=message)) @@ -112,21 +113,24 @@ class ClassicPrinter(APrinter): """ def _begin(self): - if is_compatibility_mode(): + if is_compatibility_mode() and not show_hard_api_key_warning(): return "{ " - return '{ "epidata": [' + r = '{ "epidata": [' + if show_hard_api_key_warning(): + r = f'{r} "{API_KEY_WARNING_TEXT}" ' + return r def _format_row(self, first: bool, row: Dict): - if first and is_compatibility_mode(): + if first and is_compatibility_mode() and not show_hard_api_key_warning(): sep = b'"epidata": [' else: - sep = b"," if not first else b"" + sep = b"," if not first or show_hard_api_key_warning() else b"" return sep + orjson.dumps(row) def _end(self): - message = "success" + message = API_KEY_WARNING_TEXT if show_soft_api_key_warning() else "success" prefix = "], " - if self.count == 0 and is_compatibility_mode(): + if self.count == 0 and is_compatibility_mode() and not show_hard_api_key_warning(): # no array to end prefix = "" @@ -160,7 +164,7 @@ def _format_row(self, first: bool, row: Dict): self._tree[group].append(row) else: self._tree[group] = [row] - if first and 
is_compatibility_mode(): + if first and is_compatibility_mode() and not show_hard_api_key_warning(): return b'"epidata": [' return None @@ -171,7 +175,10 @@ def _end(self): tree = orjson.dumps(self._tree) self._tree = dict() r = super(ClassicTreePrinter, self)._end() - return tree + r + r = tree + r + if show_hard_api_key_warning(): + r = b", " + r + return r class CSVPrinter(APrinter): @@ -200,8 +207,12 @@ def _error(self, error: Exception) -> str: def _format_row(self, first: bool, row: Dict): if first: - self._writer = DictWriter(self._stream, list(row.keys()), lineterminator="\n") + columns = list(row.keys()) + self._writer = DictWriter(self._stream, columns, lineterminator="\n") self._writer.writeheader() + if show_hard_api_key_warning() and columns: + self._writer.writerow({columns[0]: API_KEY_WARNING_TEXT}) + self._writer.writerow(row) # remove the stream content to print just one line at a time @@ -222,10 +233,13 @@ class JSONPrinter(APrinter): """ def _begin(self): - return b"[" + r = b"[" + if show_hard_api_key_warning(): + r = b'["' + bytes(API_KEY_WARNING_TEXT, "utf-8") + b'"' + return r def _format_row(self, first: bool, row: Dict): - sep = b"," if not first else b"" + sep = b"," if not first or show_hard_api_key_warning() else b"" return sep + orjson.dumps(row) def _end(self): @@ -240,6 +254,11 @@ class JSONLPrinter(APrinter): def make_response(self, gen): return Response(gen, mimetype=" text/plain; charset=utf8") + def _begin(self): + if show_hard_api_key_warning(): + return bytes(API_KEY_WARNING_TEXT, "utf-8") + b"\n" + return None + def _format_row(self, first: bool, row: Dict): # each line is a JSON file with a new line to separate them return orjson.dumps(row, option=orjson.OPT_APPEND_NEWLINE) diff --git a/src/server/_security.py b/src/server/_security.py new file mode 100644 index 000000000..1d4b36c84 --- /dev/null +++ b/src/server/_security.py @@ -0,0 +1,192 @@ +import re +from datetime import date, timedelta, datetime +from functools import 
wraps +from typing import Optional, cast +from uuid import uuid4 + +from flask import Response, g, request +from flask_limiter import Limiter +from flask_limiter.util import get_remote_address +from werkzeug.local import LocalProxy +import redis + +from ._common import app +from ._config import (API_KEY_REQUIRED_STARTING_AT, RATELIMIT_STORAGE_URL, + URL_PREFIX, REDIS_HOST) +from ._exceptions import MissingAPIKeyException, UnAuthenticatedException +from .admin.models import User, UserRole +# from ._logger import get_structured_logger + +API_KEY_HARD_WARNING = API_KEY_REQUIRED_STARTING_AT - timedelta(days=14) +API_KEY_SOFT_WARNING = API_KEY_HARD_WARNING - timedelta(days=14) + +API_KEY_WARNING_TEXT = ( + "an api key will be required starting at {}, go to https://delphi.cmu.edu to request one".format( + API_KEY_REQUIRED_STARTING_AT + ) +) + +TESTING_MODE = app.config.get("TESTING", False) + + +# TODO: should be fixed +# def log_info(user: User, msg: str, *args, **kwargs) -> None: +# logger = get_structured_logger("api_key_logs", filename="api_keys_log.log") +# if user.is_authenticated: +# if user.tracking: +# logger.info(msg, *args, **dict(kwargs, api_key=user.api_key)) +# else: +# logger.info(msg, *args, **dict(kwargs, apikey="*****")) +# else: +# logger.info(msg, *args, **kwargs) + + +def resolve_auth_token() -> Optional[str]: + for n in ("auth", "api_key", "token"): + if n in request.values: + return request.values[n] + # username password + if request.authorization and request.authorization.username == "epidata": + return request.authorization.password + # bearer token authentication + auth_header = request.headers.get("Authorization") + if auth_header and auth_header.startswith("Bearer "): + return auth_header[len("Bearer "):] + return None + + +def register_new_key() -> str: + api_key = str(uuid4()) + User.create_user(api_key=api_key) + return api_key + + +def mask_apikey(path: str) -> str: + # Function to mask API key query string from a request path + regexp = 
re.compile(r"[\\?&]api_key=([^&#]*)") + if regexp.search(path): + path = re.sub(regexp, "&api_key=*****", path) + return path + + +def require_api_key() -> bool: + n = date.today() + return n >= API_KEY_REQUIRED_STARTING_AT and not TESTING_MODE + + +def _get_current_user(): + if "user" not in g: + api_key = resolve_auth_token() + user = User.find_user(api_key=api_key) + request_path = request.full_path + if not user.is_authenticated: + if require_api_key(): + raise MissingAPIKeyException + if not user.tracking: + request_path = mask_apikey(request_path) + # TODO: add logging + # log_info(user, "Get path", path=request_path) + g.user = user + return g.user + + +current_user: User = cast(User, LocalProxy(_get_current_user)) + + +def show_soft_api_key_warning() -> bool: + n = date.today() + return not current_user.id and not TESTING_MODE and API_KEY_SOFT_WARNING < n < API_KEY_HARD_WARNING + + +def show_hard_api_key_warning() -> bool: + n = date.today() + return not current_user.id and n > API_KEY_HARD_WARNING and not TESTING_MODE + + +def register_user_role(role_name: str) -> None: + UserRole.create_role(role_name) + + +def _is_public_route() -> bool: + public_routes_list = ["lib", "admin", "version"] + for route in public_routes_list: + if request.path.startswith(f"{URL_PREFIX}/{route}"): + return True + return False + + +@app.before_request +def resolve_user(): + if _is_public_route(): + return + # try to get the db + try: + _get_current_user() + except MissingAPIKeyException as e: + raise e + except UnAuthenticatedException as e: + raise e + except: + app.logger.error("user connection error", exc_info=True) + if require_api_key(): + raise MissingAPIKeyException() + else: + g.user = User("anonymous") + + +def require_role(required_role: str): + def decorator_wrapper(f): + if not required_role: + return f + + @wraps(f) + def decorated_function(*args, **kwargs): + if not current_user or not current_user.has_role(required_role): + raise UnAuthenticatedException() + 
return f(*args, **kwargs) + + return decorated_function + + return decorator_wrapper + + +def _resolve_tracking_key() -> str: + token = resolve_auth_token() + return token or get_remote_address() + + +def deduct_on_success(response: Response) -> bool: + if response.status_code != 200: + return False + # check if we have the classic format + if not response.is_streamed and response.is_json: + out = response.json + if out and isinstance(out, dict) and out.get("result") == -1: + return False + return True + + +limiter = Limiter(app, key_func=_resolve_tracking_key, storage_uri=RATELIMIT_STORAGE_URL) + + +@limiter.request_filter +def _no_rate_limit() -> bool: + if TESTING_MODE or _is_public_route(): + return False + # no rate limit if user is registered + user = _get_current_user() + return user is not None and user.registered # type: ignore + + +@app.after_request +def update_key_last_time_used(response): + if _is_public_route(): + return response + try: + r = redis.Redis(host=REDIS_HOST) + api_key = g.user.api_key + r.set(f"LAST_USED/{api_key}", datetime.strftime(datetime.now(), "%Y-%m-%d")) + except Exception as e: + print(e) # TODO: should be handled properly + finally: + return response diff --git a/src/server/_validate.py b/src/server/_validate.py index 957bee09d..37adc470f 100644 --- a/src/server/_validate.py +++ b/src/server/_validate.py @@ -1,37 +1,6 @@ -from typing import Optional - from flask import Request -from ._exceptions import UnAuthenticatedException, ValidationFailedException - - -def resolve_auth_token(request: Request) -> Optional[str]: - # auth request param - if "auth" in request.values: - return request.values["auth"] - # user name password - if request.authorization and request.authorization.username == "epidata": - return request.authorization.password - # bearer token authentication - auth_header = request.headers.get("Authorization") - if auth_header and auth_header.startswith("Bearer "): - return auth_header[len("Bearer ") :] - return None - 
- -def check_auth_token(request: Request, token: str, optional=False) -> bool: - value = resolve_auth_token(request) - - if value is None: - if optional: - return False - else: - raise ValidationFailedException(f"missing parameter: auth") - - valid_token = value == token - if not valid_token and not optional: - raise UnAuthenticatedException() - return valid_token +from ._exceptions import ValidationFailedException def require_all(request: Request, *values: str) -> bool: diff --git a/src/server/admin/__init__.py b/src/server/admin/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/server/admin/models.py b/src/server/admin/models.py new file mode 100644 index 000000000..02f1e35ca --- /dev/null +++ b/src/server/admin/models.py @@ -0,0 +1,113 @@ + +from sqlalchemy import Table, ForeignKey, Column, Integer, String, Boolean, Date, delete, update +from sqlalchemy.orm import relationship +from .._db import Base, session +from typing import Set, Optional, List +from datetime import datetime as dtime + +association_table = Table( + "user_role_link", + Base.metadata, + Column("user_id", ForeignKey("api_user.id")), + Column("role_id", ForeignKey("user_role.id")), +) + + +class User(Base): + __tablename__ = "api_user" + id = Column(Integer, primary_key=True, autoincrement=True) + roles = relationship("UserRole", secondary=association_table) + api_key = Column(String(50), unique=True) + email = Column(String(50), unique=True, nullable=True) + tracking = Column(Boolean, default=True) + registered = Column(Boolean, default=False) + created = Column(Date, default=dtime.strftime(dtime.now(), "%Y-%m-%d")) + last_time_used = Column(Date, default=dtime.strftime(dtime.now(), "%Y-%m-%d")) + + def __init__(self, api_key: str, tracking: bool = True, registered: bool = False) -> None: + self.api_key = api_key + self.tracking = tracking + self.registered = registered + + @staticmethod + def list_users() -> List["User"]: + return session.query(User).all() + + 
@property + def is_authenticated(self): + return True if self.api_key != "anonymous" else False + + @property + def as_dict(self): + fields_list = ["id", "api_key", "tracking", "registered", "roles"] + user_dict = self.__dict__ + user_dict["roles"] = self.get_user_roles + return {k: v for k, v in user_dict.items() if k in fields_list} + + @property + def get_user_roles(self) -> Set[str]: + return set([role.name for role in self.roles]) + + def has_role(self, required_role: str) -> bool: + return required_role in self.get_user_roles + + @staticmethod + def assign_roles(user: "User", roles: Optional[Set[str]]) -> None: + if roles: + roles_to_assign = session.query(UserRole).filter(UserRole.name.in_(roles)).all() + user.roles = roles_to_assign + session.commit() + else: + user.roles = [] + session.commit() + + @staticmethod + def find_user(user_id: Optional[int] = None, api_key: Optional[str] = None) -> "User": + user = session.query(User).filter((User.id == user_id) | (User.api_key == api_key)).first() + return user if user else User("anonymous") + + @staticmethod + def create_user(api_key: str, user_roles: Optional[Set[str]] = None, tracking: bool = True, registered: bool = False) -> "User": + new_user = User(api_key=api_key, tracking=tracking, registered=registered) + session.add(new_user) + session.commit() + User.assign_roles(new_user, user_roles) + return new_user + + @staticmethod + def update_user(user: "User", api_key: Optional[str], roles: Optional[Set[str]], tracking: Optional[bool], registered: Optional[bool]) -> "User": + user = User.find_user(user_id=user.id) + if user: + update_stmt = update(User).where(User.id == user.id).values(api_key=api_key, tracking=tracking, registered=registered) + session.execute(update_stmt) + session.commit() + User.assign_roles(user, roles) + return user + + @staticmethod + def delete_user(user_id: int) -> None: + session.execute(delete(User).where(User.id == user_id)) + session.commit() + + +class UserRole(Base): + 
__tablename__ = "user_role" + id = Column(Integer, primary_key=True, autoincrement=True) + name = Column(String(50), unique=True) + + @staticmethod + def create_role(name: str) -> None: + session.execute(f""" + INSERT INTO user_role (name) + SELECT '{name}' + WHERE NOT EXISTS + (SELECT * + FROM user_role + WHERE name='{name}') + """) + session.commit() + + @staticmethod + def list_all_roles(): + roles = session.query(UserRole).all() + return [role.name for role in roles] diff --git a/src/server/admin/templates/index.html b/src/server/admin/templates/index.html new file mode 100644 index 000000000..2a07e02c1 --- /dev/null +++ b/src/server/admin/templates/index.html @@ -0,0 +1,87 @@ + + + + + + API Keys + + + + +
+

+ API Key Admin Interface +

+ {% if flags.banner %} + + {% endif %} + {% if mode == 'overview' %} +

Registered Users

+ + + + + + + + + + + {% for user in users %} + + + + + + + + + {% endfor %} + +
IDAPI KeyRegisteredTrack UsageRolesActions
{{ user.id }}{{ user.api_key }}{{ '✔️' if user.registered else '❌'}}{{ '✔️' if user.tracking else '❌'}}{{ ','.join(user.roles) }} + Edit + Delete +
+

Register New User

+ + {% else %} +

+ < Back + Edit User {{user.id}} +

+ {% endif %} +
+ +
+ + +
+
+ + +
+
+ + +
+
+ + {% for role in roles %} + + {% endfor %} +
+ {% if mode == 'overview' %} + + {% else %} + + + {% endif %} +
+
+ + \ No newline at end of file diff --git a/src/server/admin/templates/request.html b/src/server/admin/templates/request.html new file mode 100644 index 000000000..6493e1ab7 --- /dev/null +++ b/src/server/admin/templates/request.html @@ -0,0 +1,71 @@ + + + + + + API Keys + + + {% if recaptcha_key %} + + {% endif %} + + +
+ {% if mode == 'request' %} +

Register New Delphi Epidata Key

+
+ {% if recaptcha_key %} +
+ {% endif %} + +
+ {% else %} +

Successfully requested a new API key

+

+ Your API key is +

+ + +
+

Authentication options

+

Via request parameter

+

+ The request parameter api_key can be used to pass the api key to the server. Example: +

+

+ https://delphi.cmu.edu/epidata/covidcast/meta?api_key={{api_key}} +

+

Via Basic Authentication

+

+ Another method is to use HTTP Basic Authentication, with the username epidata and the API key as the password +

+ +
curl -u 'epidata:{{api_key}}' https://delphi.cmu.edu/epidata/covidcast/meta
+
+

Via Bearer Token

+

+ Another method is to provide the API key as a bearer token in the Authorization header +

+ +
curl -H 'Authorization: Bearer {{api_key}}' https://delphi.cmu.edu/epidata/covidcast/meta
+
+
+
+

Important Notes

+

+ This API key is rate-limited to XXX requests per hour. + To lift this limit, register your API key using this Register my API key form. +

+
+ {% endif %} +
+ + diff --git a/src/server/endpoints/admin.py b/src/server/endpoints/admin.py new file mode 100644 index 000000000..8678b73bb --- /dev/null +++ b/src/server/endpoints/admin.py @@ -0,0 +1,116 @@ +from pathlib import Path +from typing import Dict, List, Set +from flask import Blueprint, render_template_string, request, make_response +from werkzeug.exceptions import Unauthorized, NotFound, BadRequest +from werkzeug.utils import redirect +from requests import post +from .._security import resolve_auth_token, register_new_key +from .._config import ADMIN_PASSWORD, RECAPTCHA_SECRET_KEY, RECAPTCHA_SITE_KEY, REGISTER_WEBHOOK_TOKEN +from ..admin.models import User, UserRole + + +self_dir = Path(__file__).parent +# first argument is the endpoint name +bp = Blueprint("admin", __name__) + +templates_dir = Path(__file__).parent.parent / "admin" / "templates" + + +def enable_admin() -> bool: + return bool(ADMIN_PASSWORD) + + +def _require_admin(): + token = resolve_auth_token() + if token is None or token != ADMIN_PASSWORD: + raise Unauthorized() + return token + + +def _parse_roles(roles: List[str]) -> Set[str]: + return set(sorted(roles)) + + +def _render(mode: str, token: str, flags: Dict, **kwargs): + template = (templates_dir / "index.html").read_text("utf8") + return render_template_string( + template, mode=mode, token=token, flags=flags, roles=UserRole.list_all_roles(), **kwargs + ) + + +@bp.route("/", methods=["GET", "POST"]) +def _index(): + token = _require_admin() + flags = dict() + if request.method == "POST": + # register a new user + User.create_user( + request.values["api_key"], + _parse_roles(request.values.getlist("roles")), + request.values.get("tracking") == "True", + request.values.get("registered") == "True", + ) + + flags["banner"] = "Successfully Added" + users = [user.as_dict for user in User.list_users()] + return _render("overview", token, flags, users=users, user=dict()) + + +@bp.route("/", methods=["GET", "PUT", "POST", "DELETE"]) +def 
_detail(user_id: int): + token = _require_admin() + user = User.find_user(user_id=user_id) + if not user: + raise NotFound() + if request.method == "DELETE" or "delete" in request.values: + User.delete_user(user.id) + return redirect(f"./?auth={token}") + flags = dict() + if request.method == "PUT" or request.method == "POST": + user = user.update_user( + user, + request.values["api_key"], + _parse_roles(request.values.getlist("roles")), + request.values.get("tracking") == "True", + request.values.get("registered") == "True", + ) + flags['banner'] = 'Successfully Saved' + return _render("detail", token, flags, user=user.as_dict) + + +@bp.route("/register", methods=["POST"]) +def _register(): + body = request.get_json() + token = body.get("token") + if token is None or token != REGISTER_WEBHOOK_TOKEN: + raise Unauthorized() + + old_api_key = body["user_old_api_key"] + user = User.find_user(api_key=old_api_key) + if user is None: + raise BadRequest("invalid api key") + new_api_key = body["user_new_api_key"] + tracking = True if body["tracking"] == "Yes" else False + user = user.update_user(user, new_api_key, user.roles, tracking, True) + return make_response(f'Successfully registered the API key "{new_api_key}" and removed rate limit', 200) + + +def _verify_recaptcha(): + recaptcha_response = request.values["g-recaptcha-response"] + url = "https://www.google.com/recaptcha/api/siteverify" + # skip remote ip for now since behind proxy + res = post(url, params=dict(secret=RECAPTCHA_SECRET_KEY, response=recaptcha_response)).json() + if res["success"] is not True: + raise BadRequest("invalid recaptcha key") + + +@bp.route("/create_key", methods=["GET", "POST"]) +def _request_api_key(): + template = (templates_dir / "request.html").read_text("utf8") + if request.method == "GET": + return render_template_string(template, mode="request", recaptcha_key=RECAPTCHA_SITE_KEY) + if request.method == "POST": + if RECAPTCHA_SECRET_KEY: + _verify_recaptcha() + api_key = 
register_new_key() + return render_template_string(template, mode="result", api_key=api_key) diff --git a/src/server/endpoints/afhsb.py b/src/server/endpoints/afhsb.py index 92cee145c..a006defac 100644 --- a/src/server/endpoints/afhsb.py +++ b/src/server/endpoints/afhsb.py @@ -2,10 +2,10 @@ from flask import Blueprint, request -from .._config import AUTH from .._params import extract_integers, extract_strings from .._query import execute_queries, filter_integers, filter_strings -from .._validate import check_auth_token, require_all +from .._validate import require_all +from .._security import require_role # first argument is the endpoint name bp = Blueprint("afhsb", __name__) @@ -53,8 +53,8 @@ def _split_flu_types(flu_types: List[str]): @bp.route("/", methods=("GET", "POST")) +@require_role("afhsb") def handle(): - check_auth_token(request, AUTH["afhsb"]) require_all(request, "locations", "epiweeks", "flu_types") locations = extract_strings("locations") diff --git a/src/server/endpoints/cdc.py b/src/server/endpoints/cdc.py index e89eb94fb..7239402de 100644 --- a/src/server/endpoints/cdc.py +++ b/src/server/endpoints/cdc.py @@ -1,9 +1,10 @@ -from flask import Blueprint, request +from flask import Blueprint -from .._config import AUTH, NATION_REGION, REGION_TO_STATE +from .._config import NATION_REGION, REGION_TO_STATE from .._params import extract_strings, extract_integers from .._query import filter_strings, execute_queries, filter_integers -from .._validate import require_all, check_auth_token +from .._validate import require_all +from .._security import require_role # first argument is the endpoint name bp = Blueprint("cdc", __name__) @@ -11,9 +12,9 @@ @bp.route("/", methods=("GET", "POST")) +@require_role("cdc") def handle(): - check_auth_token(request, AUTH["cdc"]) - require_all(request, "locations", "epiweeks") + require_all("locations", "epiweeks") # parse the request locations = extract_strings("locations") diff --git a/src/server/endpoints/dengue_sensors.py 
b/src/server/endpoints/dengue_sensors.py index f8286eacd..52d2b231e 100644 --- a/src/server/endpoints/dengue_sensors.py +++ b/src/server/endpoints/dengue_sensors.py @@ -1,9 +1,9 @@ -from flask import Blueprint, request +from flask import Blueprint -from .._config import AUTH from .._params import extract_integers, extract_strings from .._query import execute_query, QueryBuilder -from .._validate import check_auth_token, require_all +from .._validate import require_all +from .._security import require_role # first argument is the endpoint name bp = Blueprint("dengue_sensors", __name__) @@ -11,9 +11,9 @@ @bp.route("/", methods=("GET", "POST")) +@require_role("sensors") def handle(): - check_auth_token(request, AUTH["sensors"]) - require_all(request, "names", "locations", "epiweeks") + require_all("names", "locations", "epiweeks") names = extract_strings("names") locations = extract_strings("locations") @@ -21,7 +21,7 @@ def handle(): # build query q = QueryBuilder("dengue_sensors", "s") - + fields_string = ["name", "location"] fields_int = ["epiweek"] fields_float = ["value"] diff --git a/src/server/endpoints/fluview.py b/src/server/endpoints/fluview.py index 262cbeb27..a1f29d8ae 100644 --- a/src/server/endpoints/fluview.py +++ b/src/server/endpoints/fluview.py @@ -2,7 +2,7 @@ from flask import Blueprint, request -from .._config import AUTH +from .._security import current_user from .._params import ( extract_integer, extract_integers, @@ -10,7 +10,6 @@ ) from .._query import execute_queries, filter_integers, filter_strings from .._validate import ( - check_auth_token, require_all, ) @@ -21,7 +20,7 @@ @bp.route("/", methods=("GET", "POST")) def handle(): - authorized = check_auth_token(request, AUTH["fluview"], optional=True) + authorized = current_user.has_role("fluview") require_all(request, "epiweeks", "regions") diff --git a/src/server/endpoints/ght.py b/src/server/endpoints/ght.py index 24ba84c23..2a32f007e 100644 --- a/src/server/endpoints/ght.py +++ 
b/src/server/endpoints/ght.py @@ -1,9 +1,9 @@ from flask import Blueprint, request -from .._config import AUTH from .._params import extract_integers, extract_strings from .._query import execute_query, QueryBuilder -from .._validate import check_auth_token, require_all +from .._validate import require_all +from .._security import require_role # first argument is the endpoint name bp = Blueprint("ght", __name__) @@ -11,9 +11,9 @@ @bp.route("/", methods=("GET", "POST")) +@require_role("ght") def handle(): - check_auth_token(request, AUTH["ght"]) - require_all(request, "locations", "epiweeks", "query") + require_all("locations", "epiweeks", "query") locations = extract_strings("locations") epiweeks = extract_integers("epiweeks") diff --git a/src/server/endpoints/meta_afhsb.py b/src/server/endpoints/meta_afhsb.py index 8a74b51ca..096ab58ec 100644 --- a/src/server/endpoints/meta_afhsb.py +++ b/src/server/endpoints/meta_afhsb.py @@ -1,9 +1,9 @@ from flask import Blueprint, request -from .._config import AUTH from .._printer import print_non_standard from .._query import parse_result -from .._validate import check_auth_token +from .._security import require_role + # first argument is the endpoint name bp = Blueprint("meta_afhsb", __name__) @@ -11,9 +11,8 @@ @bp.route("/", methods=("GET", "POST")) +@require_role("afhsb") def handle(): - check_auth_token(request, AUTH["afhsb"]) - # build query table1 = "afhsb_00to13_state" table2 = "afhsb_13to17_state" diff --git a/src/server/endpoints/meta_norostat.py b/src/server/endpoints/meta_norostat.py index 789b09021..ce24de6b4 100644 --- a/src/server/endpoints/meta_norostat.py +++ b/src/server/endpoints/meta_norostat.py @@ -1,9 +1,8 @@ from flask import Blueprint, request -from .._config import AUTH from .._printer import print_non_standard from .._query import parse_result -from .._validate import check_auth_token +from .._security import require_role # first argument is the endpoint name bp = Blueprint("meta_norostat", __name__) 
@@ -11,9 +10,8 @@ @bp.route("/", methods=("GET", "POST")) +@require_role("norostat") def handle(): - check_auth_token(request, AUTH["norostat"]) - # build query query = "SELECT DISTINCT `release_date` FROM `norostat_raw_datatable_version_list`" releases = parse_result(query, {}, ["release_date"]) diff --git a/src/server/endpoints/norostat.py b/src/server/endpoints/norostat.py index 7dc06d443..28c106bc1 100644 --- a/src/server/endpoints/norostat.py +++ b/src/server/endpoints/norostat.py @@ -1,9 +1,9 @@ from flask import Blueprint, request -from .._config import AUTH from .._params import extract_integers from .._query import execute_query, filter_integers, filter_strings -from .._validate import check_auth_token, require_all +from .._validate import require_all +from .._security import require_role # first argument is the endpoint name bp = Blueprint("norostat", __name__) @@ -11,9 +11,9 @@ @bp.route("/", methods=("GET", "POST")) +@require_role("norostat") def handle(): - check_auth_token(request, AUTH["norostat"]) - require_all(request, "location", "epiweeks") + require_all("location", "epiweeks") location = request.values["location"] epiweeks = extract_integers("epiweeks") diff --git a/src/server/endpoints/quidel.py b/src/server/endpoints/quidel.py index 6de9205b8..64531f9bb 100644 --- a/src/server/endpoints/quidel.py +++ b/src/server/endpoints/quidel.py @@ -1,9 +1,9 @@ -from flask import Blueprint, request +from flask import Blueprint -from .._config import AUTH from .._params import extract_integers, extract_strings from .._query import execute_query, QueryBuilder -from .._validate import check_auth_token, require_all +from .._validate import require_all +from .._security import require_role # first argument is the endpoint name bp = Blueprint("quidel", __name__) @@ -11,9 +11,9 @@ @bp.route("/", methods=("GET", "POST")) +@require_role("quidel") def handle(): - check_auth_token(request, AUTH["quidel"]) - require_all(request, "locations", "epiweeks") + 
require_all("locations", "epiweeks") locations = extract_strings("locations") epiweeks = extract_integers("epiweeks") diff --git a/src/server/endpoints/sensors.py b/src/server/endpoints/sensors.py index cd76ca4d8..fbd1337e7 100644 --- a/src/server/endpoints/sensors.py +++ b/src/server/endpoints/sensors.py @@ -1,95 +1,43 @@ -from flask import Blueprint, Request, request +from flask import Blueprint, request -from .._config import AUTH, GRANULAR_SENSOR_AUTH_TOKENS, OPEN_SENSORS from .._exceptions import EpiDataException from .._params import ( extract_strings, extract_integers, ) +from .._security import current_user +from .._config import GRANULAR_SENSOR_ROLES, OPEN_SENSORS from .._query import filter_strings, execute_query, filter_integers -from .._validate import ( - require_all, - resolve_auth_token, -) +from .._validate import require_all from typing import List # first argument is the endpoint name bp = Blueprint("sensors", __name__) alias = "signals" -# Limits on the number of effective auth token equality checks performed per sensor query; generate auth tokens with appropriate levels of entropy according to the limits below: -MAX_GLOBAL_AUTH_CHECKS_PER_SENSOR_QUERY = 1 # (but imagine is larger to futureproof) -MAX_GRANULAR_AUTH_CHECKS_PER_SENSOR_QUERY = 30 # (but imagine is larger to futureproof) -# A (currently redundant) limit on the number of auth tokens that can be provided: -MAX_AUTH_KEYS_PROVIDED_PER_SENSOR_QUERY = 1 -# end sensor query authentication configuration - -PHP_INT_MAX = 2147483647 - - -def _authenticate(req: Request, names: List[str]): - auth_tokens_presented = (resolve_auth_token(req) or "").split(",") +def _authenticate(names: List[str]): names = extract_strings("names") n_names = len(names) - n_auth_tokens_presented = len(auth_tokens_presented) - - max_valid_granular_tokens_per_name = max( - len(v) for v in GRANULAR_SENSOR_AUTH_TOKENS.values() - ) # The number of valid granular tokens is related to the number of auth token checks that a 
single query could perform. Use the max number of valid granular auth tokens per name in the check below as a way to prevent leakage of sensor names (but revealing the number of sensor names) via this interface. Treat all sensors as non-open for convenience of calculation. if n_names == 0: # Check whether no names were provided to prevent edge-case issues in error message below, and in case surrounding behavior changes in the future: raise EpiDataException("no sensor names provided") - if n_auth_tokens_presented > 1: - raise EpiDataException( - "currently, only a single auth token is allowed to be presented at a time; please issue a separate query for each sensor name using only the corresponding token" - ) - - # Check whether max number of presented-vs.-acceptable token comparisons that would be performed is over the set limits, avoiding calculation of numbers > PHP_INT_MAX/100: - # Global auth token comparison limit check: - # Granular auth token comparison limit check: - if ( - n_auth_tokens_presented > MAX_GLOBAL_AUTH_CHECKS_PER_SENSOR_QUERY - or n_names - > int((PHP_INT_MAX / 100 - 1) / max(1, max_valid_granular_tokens_per_name)) - or n_auth_tokens_presented - > int(PHP_INT_MAX / 100 / max(1, n_names * max_valid_granular_tokens_per_name)) - or n_auth_tokens_presented * n_names * max_valid_granular_tokens_per_name - > MAX_GRANULAR_AUTH_CHECKS_PER_SENSOR_QUERY - ): - raise EpiDataException( - "too many sensors requested and/or auth tokens presented; please divide sensors into batches and/or use only the tokens needed for the sensors requested" - ) - - if len(auth_tokens_presented) > MAX_AUTH_KEYS_PROVIDED_PER_SENSOR_QUERY: - # this check should be redundant with >1 check as well as global check above - raise EpiDataException("too many auth tokens presented") unauthenticated_or_nonexistent_sensors = [] for name in names: sensor_is_open = name in OPEN_SENSORS # test whether they provided the "global" auth token that works for all sensors: - 
sensor_authenticated_globally = AUTH["sensors"] in auth_tokens_presented + sensor_authenticated_globally = current_user.has_role("sensors") # test whether they provided a "granular" auth token for one of the # sensor_subsets containing this sensor (if any): sensor_authenticated_granularly = False - if name in GRANULAR_SENSOR_AUTH_TOKENS: - acceptable_granular_tokens_for_sensor = GRANULAR_SENSOR_AUTH_TOKENS[name] - # check for nonempty intersection between provided and acceptable - # granular auth tokens: - for acceptable_granular_token in acceptable_granular_tokens_for_sensor: - if acceptable_granular_token in auth_tokens_presented: - sensor_authenticated_granularly = True - break + if name in GRANULAR_SENSOR_ROLES and current_user.has_role(GRANULAR_SENSOR_ROLES[name]): + sensor_authenticated_granularly = True # (else: there are no granular tokens for this sensor; can't authenticate granularly) - if ( - not sensor_is_open - and not sensor_authenticated_globally - and not sensor_authenticated_granularly - ): + if not sensor_is_open and not sensor_authenticated_globally and not sensor_authenticated_granularly: # authentication failed for this sensor; append to list: unauthenticated_or_nonexistent_sensors.append(name) diff --git a/src/server/endpoints/twitter.py b/src/server/endpoints/twitter.py index 84cbb2850..0a96adfd0 100644 --- a/src/server/endpoints/twitter.py +++ b/src/server/endpoints/twitter.py @@ -1,16 +1,16 @@ from flask import Blueprint, request -from .._config import AUTH, NATION_REGION, REGION_TO_STATE +from .._config import NATION_REGION, REGION_TO_STATE from .._params import ( extract_integers, extract_strings, ) from .._query import execute_queries, filter_dates, filter_integers, filter_strings from .._validate import ( - check_auth_token, require_all, require_any, ) +from .._security import require_role # first argument is the endpoint name bp = Blueprint("twitter", __name__) @@ -18,8 +18,8 @@ @bp.route("/", methods=("GET", "POST")) 
+@require_role("twitter") def handle(): -    check_auth_token(request, AUTH["twitter"]) -    require_all(request, "locations") -    require_any(request, "dates", "epiweeks") +    require_all("locations") +    require_any("dates", "epiweeks") diff --git a/src/server/main.py b/src/server/main.py index 7471a2491..72d7ba04c 100644 --- a/src/server/main.py +++ b/src/server/main.py @@ -4,10 +4,14 @@ from flask import request, send_file, Response, send_from_directory, jsonify -from ._config import URL_PREFIX, VERSION +from ._config import URL_PREFIX, VERSION, RATE_LIMIT +from ._db import metadata, engine from ._common import app, set_compatibility_mode from ._exceptions import MissingOrWrongSourceException from .endpoints import endpoints +from .endpoints.admin import bp as admin_bp, enable_admin +from ._security import limiter, deduct_on_success, register_user_role +from ._config import UserRole __all__ = ["app"] @@ -15,14 +19,23 @@ for endpoint in endpoints: endpoint_map[endpoint.bp.name] = endpoint.handle + limiter.limit(RATE_LIMIT, deduct_when=deduct_on_success)(endpoint.bp) app.register_blueprint(endpoint.bp, url_prefix=f"{URL_PREFIX}/{endpoint.bp.name}") - alias = getattr(endpoint, "alias", None) if alias: endpoint_map[alias] = endpoint.handle +with app.app_context(): + for role in UserRole: + register_user_role(role.name) + +if enable_admin(): + limiter.exempt(admin_bp) + app.register_blueprint(admin_bp, url_prefix=f"{URL_PREFIX}/admin") + @app.route(f"{URL_PREFIX}/api.php", methods=["GET", "POST"]) +@limiter.limit(RATE_LIMIT, deduct_when=deduct_on_success) def handle_generic(): # mark as compatibility mode set_compatibility_mode() @@ -49,6 +62,8 @@ def send_lib_file(path: str): return send_from_directory(pathlib.Path(__file__).parent / "lib", path) +metadata.create_all(engine) + if __name__ == "__main__": app.run(host="0.0.0.0", port=5000) else: @@ -57,4 +72,6 @@ def send_lib_file(path: str): app.logger.handlers = gunicorn_logger.handlers app.logger.setLevel(gunicorn_logger.level) sqlalchemy_logger = logging.getLogger("sqlalchemy")
sqlalchemy_logger.setLevel(logging.WARN) + sqlalchemy_logger.handlers = gunicorn_logger.handlers + sqlalchemy_logger.setLevel(logging.ERROR) + #sqlalchemy_logger.setLevel(gunicorn_logger.level) diff --git a/tests/server/test_security.py b/tests/server/test_security.py new file mode 100644 index 000000000..e209d9342 --- /dev/null +++ b/tests/server/test_security.py @@ -0,0 +1,47 @@ +"""Unit tests for auth token resolution in the server _security module.""" + +# standard library +import unittest +import base64 + +# from flask.testing import FlaskClient +from delphi.epidata.server._common import app +from delphi.epidata.server._security import ( + resolve_auth_token, +) + +# py3tester coverage target +__test_target__ = "delphi.epidata.server._security" + + +class UnitTests(unittest.TestCase): + """Basic unit tests.""" + + # app: FlaskClient + + def setUp(self): + app.config["TESTING"] = True + app.config["WTF_CSRF_ENABLED"] = False + app.config["DEBUG"] = False + + def test_resolve_auth_token(self): + with self.subTest("no auth"): + with app.test_request_context("/"): + self.assertIsNone(resolve_auth_token()) + + with self.subTest("param"): + with app.test_request_context("/?auth=abc"): + self.assertEqual(resolve_auth_token(), "abc") + + with self.subTest("param2"): + with app.test_request_context("/?api_key=abc"): + self.assertEqual(resolve_auth_token(), "abc") + + with self.subTest("bearer token"): + with app.test_request_context("/", headers={"Authorization": "Bearer abc"}): + self.assertEqual(resolve_auth_token(), "abc") + + with self.subTest("basic token"): + userpass = base64.b64encode(b"epidata:abc").decode("utf-8") + with app.test_request_context("/", headers={"Authorization": f"Basic {userpass}"}): + self.assertEqual(resolve_auth_token(), "abc") diff --git a/tests/server/test_validate.py b/tests/server/test_validate.py index 22a4f153c..27ce28672 100644 --- a/tests/server/test_validate.py +++ b/tests/server/test_validate.py @@ -2,21 +2,17 @@ # standard library import
unittest -import base64 from flask import request # from flask.testing import FlaskClient from delphi.epidata.server._common import app from delphi.epidata.server._validate import ( - resolve_auth_token, - check_auth_token, require_all, require_any, ) from delphi.epidata.server._exceptions import ( ValidationFailedException, - UnAuthenticatedException, ) # py3tester coverage target @@ -33,47 +29,6 @@ def setUp(self): app.config["WTF_CSRF_ENABLED"] = False app.config["DEBUG"] = False - def test_resolve_auth_token(self): - with self.subTest("no auth"): - with app.test_request_context("/"): - self.assertIsNone(resolve_auth_token(request)) - - with self.subTest("param"): - with app.test_request_context("/?auth=abc"): - self.assertEqual(resolve_auth_token(request), "abc") - - with self.subTest("bearer token"): - with app.test_request_context("/", headers={"Authorization": "Bearer abc"}): - self.assertEqual(resolve_auth_token(request), "abc") - - with self.subTest("basic token"): - userpass = base64.b64encode(b"epidata:abc").decode("utf-8") - with app.test_request_context( - "/", headers={"Authorization": f"Basic {userpass}"} - ): - self.assertEqual(resolve_auth_token(request), "abc") - - def test_check_auth_token(self): - with self.subTest("no auth but optional"): - with app.test_request_context("/"): - self.assertFalse(check_auth_token(request, "abc", True)) - with self.subTest("no auth but required"): - with app.test_request_context("/"): - self.assertRaises( - ValidationFailedException, lambda: check_auth_token(request, "abc") - ) - with self.subTest("auth and required"): - with app.test_request_context("/?auth=abc"): - self.assertTrue(check_auth_token(request, "abc")) - with self.subTest("auth and required but wrong"): - with app.test_request_context("/?auth=abc"): - self.assertRaises( - UnAuthenticatedException, lambda: check_auth_token(request, "def") - ) - with self.subTest("auth and required but wrong but optional"): - with 
app.test_request_context("/?auth=abc"): - self.assertFalse(check_auth_token(request, "def", True)) - def test_require_all(self): with self.subTest("all given"): with app.test_request_context("/"):