Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions patches/kaggle_web_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,14 +30,12 @@ def __init__(self):
f'but none found in environment variable {_KAGGLE_USER_SECRETS_TOKEN_ENV_VAR_NAME}')
self.headers = {
'Content-type': 'application/json',
'Authorization': f'Bearer {self.jwt_token}',
'X-Kaggle-Authorization': f'Bearer {self.jwt_token}',
}

def make_post_request(self, data: dict, endpoint: str, timeout: int = TIMEOUT_SECS) -> dict:
url = f'{self.url_base}{endpoint}'
request_body = dict(data)
request_body['JWE'] = self.jwt_token
req = urllib.request.Request(url, headers=self.headers, data=bytes(
json.dumps(request_body), encoding="utf-8"))
try:
Expand Down
14 changes: 7 additions & 7 deletions tests/test_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,10 +85,10 @@ def get_response(self):
msg="Fake server did not receive a Content-Type header from the KaggleDatasets client.")
self.assertEqual('application/json', headers.get('Content-Type'),
msg="Fake server did not receive an application/json content type header from the KaggleDatasets client.")
self.assertIn('Authorization', headers.keys(),
msg="Fake server did not receive an Authorization header from the KaggleDatasets client.")
self.assertEqual(f'Bearer {_TEST_JWT}', headers.get('Authorization'),
msg="Fake server did not receive the right Authorization header from the KaggleDatasets client.")
self.assertIn('X-Kaggle-Authorization', headers.keys(),
msg="Fake server did not receive an X-Kaggle-Authorization header from the KaggleDatasets client.")
self.assertEqual(f'Bearer {_TEST_JWT}', headers.get('X-Kaggle-Authorization'),
msg="Fake server did not receive the right X-Kaggle-Authorization header from the KaggleDatasets client.")

def test_no_token_fails(self):
env = EnvironmentVarGuard()
Expand All @@ -104,7 +104,7 @@ def call_get_gcs_path():
self.assertEqual(gcs_path, _TPU_GCS_BUCKET)
self._test_client(call_get_gcs_path,
'/requests/CopyDatasetVersionToKnownGcsBucketRequest',
{'MountSlug': None, 'IntegrationType': 2, 'JWE': _TEST_JWT},
{'MountSlug': None, 'IntegrationType': 2},
is_tpu=True)

def test_get_gcs_path_automl_succeeds(self):
Expand All @@ -114,7 +114,7 @@ def call_get_gcs_path():
self.assertEqual(gcs_path, _AUTOML_GCS_BUCKET)
self._test_client(call_get_gcs_path,
'/requests/CopyDatasetVersionToKnownGcsBucketRequest',
{'MountSlug': None, 'IntegrationType': 1, 'JWE': _TEST_JWT},
{'MountSlug': None, 'IntegrationType': 1},
is_tpu=False)

def test_get_gcs_path_handles_unsuccessful(self):
Expand All @@ -124,6 +124,6 @@ def call_get_gcs_path():
gcs_path = client.get_gcs_path()
self._test_client(call_get_gcs_path,
'/requests/CopyDatasetVersionToKnownGcsBucketRequest',
{'MountSlug': None, 'IntegrationType': 2, 'JWE': _TEST_JWT},
{'MountSlug': None, 'IntegrationType': 2},
is_tpu=True,
success=False)
14 changes: 7 additions & 7 deletions tests/test_user_secrets.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def call_get_secret():
secret_response = client.get_secret("secret_label")
self.assertEqual(secret_response, secret)
self._test_client(call_get_secret,
'/requests/GetUserSecretByLabelRequest', {'Label': "secret_label", 'JWE': _TEST_JWT},
'/requests/GetUserSecretByLabelRequest', {'Label': "secret_label"},
secret=secret)

def test_get_secret_handles_unsuccessful(self):
Expand All @@ -103,7 +103,7 @@ def call_get_secret():
with self.assertRaises(BackendError):
secret_response = client.get_secret("secret_label")
self._test_client(call_get_secret,
'/requests/GetUserSecretByLabelRequest', {'Label': "secret_label", 'JWE': _TEST_JWT},
'/requests/GetUserSecretByLabelRequest', {'Label': "secret_label"},
success=False)

def test_get_secret_validates_label(self):
Expand All @@ -122,7 +122,7 @@ def call_get_secret():
secret_response = client.get_gcloud_credential()
self.assertEqual(secret_response, secret)
self._test_client(call_get_secret,
'/requests/GetUserSecretByLabelRequest', {'Label': "__gcloud_sdk_auth__", 'JWE': _TEST_JWT},
'/requests/GetUserSecretByLabelRequest', {'Label': "__gcloud_sdk_auth__"},
secret=secret)

def test_get_gcloud_secret_handles_unsuccessful(self):
Expand All @@ -131,7 +131,7 @@ def call_get_secret():
with self.assertRaises(NotFoundError):
secret_response = client.get_gcloud_credential()
self._test_client(call_get_secret,
'/requests/GetUserSecretByLabelRequest', {'Label': "__gcloud_sdk_auth__", 'JWE': _TEST_JWT},
'/requests/GetUserSecretByLabelRequest', {'Label': "__gcloud_sdk_auth__"},
success=False)


Expand All @@ -150,10 +150,10 @@ def call_get_gcs_access_token():
secret_response = client._get_gcs_access_token()
self.assertEqual(secret_response, (secret, now + timedelta(seconds=3600)))
self._test_client(call_get_bigquery_access_token,
'/requests/GetUserSecretRequest', {'Target': GcpTarget.BIGQUERY.target, 'JWE': _TEST_JWT},
'/requests/GetUserSecretRequest', {'Target': GcpTarget.BIGQUERY.target},
secret=secret)
self._test_client(call_get_gcs_access_token,
'/requests/GetUserSecretRequest', {'Target': GcpTarget.GCS.target, 'JWE': _TEST_JWT},
'/requests/GetUserSecretRequest', {'Target': GcpTarget.GCS.target},
secret=secret)

def test_get_access_token_handles_unsuccessful(self):
Expand All @@ -162,4 +162,4 @@ def call_get_access_token():
with self.assertRaises(BackendError):
client.get_bigquery_access_token()
self._test_client(call_get_access_token,
'/requests/GetUserSecretRequest', {'Target': GcpTarget.BIGQUERY.target, 'JWE': _TEST_JWT}, success=False)
'/requests/GetUserSecretRequest', {'Target': GcpTarget.BIGQUERY.target}, success=False)