99from kaggle_web_client import (KaggleWebClient ,
1010 _KAGGLE_URL_BASE_ENV_VAR_NAME ,
1111 _KAGGLE_USER_SECRETS_TOKEN_ENV_VAR_NAME ,
12+ _KAGGLE_IAP_TOKEN_ENV_VAR_NAME ,
1213 CredentialError , BackendError )
1314from kaggle_datasets import KaggleDatasets , _KAGGLE_TPU_NAME_ENV_VAR_NAME
1415
1516_TEST_JWT = 'test-secrets-key'
17+ _TEST_IAP = 'IAP_TOKEN'
1618
1719_TPU_GCS_BUCKET = 'gs://kds-tpu-ea1971a458ffd4cd51389e7574c022ecc0a82bb1b52ccef08c8a'
1820_AUTOML_GCS_BUCKET = 'gs://kds-automl-ea1971a458ffd4cd51389e7574c022ecc0a82bb1b52ccef08c8a'
@@ -39,7 +41,7 @@ def do_POST(s):
3941class TestDatasets (unittest .TestCase ):
4042 SERVER_ADDRESS = urlparse (os .getenv (_KAGGLE_URL_BASE_ENV_VAR_NAME , default = "http://127.0.0.1:8001" ))
4143
42- def _test_client (self , client_func , expected_path , expected_body , is_tpu = True , success = True ):
44+ def _test_client (self , client_func , expected_path , expected_body , is_tpu = True , success = True , iap_token = False ):
4345 _request = {}
4446
4547 class GetGcsPathHandler (GcsDatasetsHTTPHandler ):
@@ -63,6 +65,8 @@ def get_response(self):
6365 env .set (_KAGGLE_USER_SECRETS_TOKEN_ENV_VAR_NAME , _TEST_JWT )
6466 if is_tpu :
6567 env .set (_KAGGLE_TPU_NAME_ENV_VAR_NAME , 'FAKE_TPU' )
68+ if iap_token :
69+ env .set (_KAGGLE_IAP_TOKEN_ENV_VAR_NAME , _TEST_IAP )
6670 with env :
6771 with HTTPServer ((self .SERVER_ADDRESS .hostname , self .SERVER_ADDRESS .port ), GetGcsPathHandler ) as httpd :
6872 threading .Thread (target = httpd .serve_forever ).start ()
@@ -87,6 +91,12 @@ def get_response(self):
8791 msg = "Fake server did not receive an application/json content type header from the KaggleDatasets client." )
8892 self .assertIn ('X-Kaggle-Authorization' , headers .keys (),
8993 msg = "Fake server did not receive an X-Kaggle-Authorization header from the KaggleDatasets client." )
94+ if iap_token :
95+ self .assertEqual (f'Bearer { _TEST_IAP } ' , headers .get ('Authorization' ),
96+ msg = "Fake server did not receive an Authorization header from the KaggleDatasets client." )
97+ else :
98+ self .assertNotIn ('Authorization' , headers .keys (),
99+ msg = "Fake server received an Authorization header from the KaggleDatasets client. It shouldn't." )
90100 self .assertEqual (f'Bearer { _TEST_JWT } ' , headers .get ('X-Kaggle-Authorization' ),
91101 msg = "Fake server did not receive the right X-Kaggle-Authorization header from the KaggleDatasets client." )
92102
@@ -127,3 +137,12 @@ def call_get_gcs_path():
127137 {'MountSlug' : None , 'IntegrationType' : 2 },
128138 is_tpu = True ,
129139 success = False )
140+
141+ def test_iap_token (self ):
142+ def call_get_gcs_path ():
143+ client = KaggleDatasets ()
144+ gcs_path = client .get_gcs_path ()
145+ self ._test_client (call_get_gcs_path ,
146+ '/requests/CopyDatasetVersionToKnownGcsBucketRequest' ,
147+ {'MountSlug' : None , 'IntegrationType' : 1 },
148+ is_tpu = False , iap_token = True )
0 commit comments