diff --git a/patches/kaggle_web_client.py b/patches/kaggle_web_client.py index 66d8c13c..f7b7ae8b 100644 --- a/patches/kaggle_web_client.py +++ b/patches/kaggle_web_client.py @@ -7,6 +7,7 @@ _KAGGLE_DEFAULT_URL_BASE = "https://www.kaggle.com" _KAGGLE_URL_BASE_ENV_VAR_NAME = "KAGGLE_URL_BASE" _KAGGLE_USER_SECRETS_TOKEN_ENV_VAR_NAME = "KAGGLE_USER_SECRETS_TOKEN" +_KAGGLE_IAP_TOKEN_ENV_VAR_NAME = "KAGGLE_IAP_TOKEN" TIMEOUT_SECS = 40 class CredentialError(Exception): @@ -32,6 +33,9 @@ def __init__(self): 'Content-type': 'application/json', 'X-Kaggle-Authorization': f'Bearer {self.jwt_token}', } + iap_token = os.getenv(_KAGGLE_IAP_TOKEN_ENV_VAR_NAME) + if iap_token: + self.headers['Authorization'] = f'Bearer {iap_token}' def make_post_request(self, data: dict, endpoint: str, timeout: int = TIMEOUT_SECS) -> dict: url = f'{self.url_base}{endpoint}' diff --git a/tests/test_datasets.py b/tests/test_datasets.py index 02ffee6e..029a4570 100644 --- a/tests/test_datasets.py +++ b/tests/test_datasets.py @@ -9,10 +9,12 @@ from kaggle_web_client import (KaggleWebClient, _KAGGLE_URL_BASE_ENV_VAR_NAME, _KAGGLE_USER_SECRETS_TOKEN_ENV_VAR_NAME, + _KAGGLE_IAP_TOKEN_ENV_VAR_NAME, CredentialError, BackendError) from kaggle_datasets import KaggleDatasets, _KAGGLE_TPU_NAME_ENV_VAR_NAME _TEST_JWT = 'test-secrets-key' +_TEST_IAP = 'IAP_TOKEN' _TPU_GCS_BUCKET = 'gs://kds-tpu-ea1971a458ffd4cd51389e7574c022ecc0a82bb1b52ccef08c8a' _AUTOML_GCS_BUCKET = 'gs://kds-automl-ea1971a458ffd4cd51389e7574c022ecc0a82bb1b52ccef08c8a' @@ -39,7 +41,7 @@ def do_POST(s): class TestDatasets(unittest.TestCase): SERVER_ADDRESS = urlparse(os.getenv(_KAGGLE_URL_BASE_ENV_VAR_NAME, default="http://127.0.0.1:8001")) - def _test_client(self, client_func, expected_path, expected_body, is_tpu=True, success=True): + def _test_client(self, client_func, expected_path, expected_body, is_tpu=True, success=True, iap_token=False): _request = {} class GetGcsPathHandler(GcsDatasetsHTTPHandler): @@ -63,6 +65,8 @@ def get_response(self): env.set(_KAGGLE_USER_SECRETS_TOKEN_ENV_VAR_NAME, _TEST_JWT) if is_tpu: env.set(_KAGGLE_TPU_NAME_ENV_VAR_NAME, 'FAKE_TPU') + if iap_token: + env.set(_KAGGLE_IAP_TOKEN_ENV_VAR_NAME, _TEST_IAP) with env: with HTTPServer((self.SERVER_ADDRESS.hostname, self.SERVER_ADDRESS.port), GetGcsPathHandler) as httpd: threading.Thread(target=httpd.serve_forever).start() @@ -87,6 +91,12 @@ def get_response(self): msg="Fake server did not receive an application/json content type header from the KaggleDatasets client.") self.assertIn('X-Kaggle-Authorization', headers.keys(), msg="Fake server did not receive an X-Kaggle-Authorization header from the KaggleDatasets client.") + if iap_token: + self.assertEqual(f'Bearer {_TEST_IAP}', headers.get('Authorization'), + msg="Fake server did not receive an Authorization header from the KaggleDatasets client.") + else: + self.assertNotIn('Authorization', headers.keys(), + msg="Fake server received an Authorization header from the KaggleDatasets client. It shouldn't.") self.assertEqual(f'Bearer {_TEST_JWT}', headers.get('X-Kaggle-Authorization'), msg="Fake server did not receive the right X-Kaggle-Authorization header from the KaggleDatasets client.") @@ -127,3 +137,12 @@ def call_get_gcs_path(): {'MountSlug': None, 'IntegrationType': 2}, is_tpu=True, success=False) + + def test_iap_token(self): + def call_get_gcs_path(): + client = KaggleDatasets() + gcs_path = client.get_gcs_path() + self._test_client(call_get_gcs_path, + '/requests/CopyDatasetVersionToKnownGcsBucketRequest', + {'MountSlug': None, 'IntegrationType': 1}, + is_tpu=False, iap_token=True)