diff --git a/patches/kaggle_secrets.py b/patches/kaggle_secrets.py index a52896b0..18bd8988 100644 --- a/patches/kaggle_secrets.py +++ b/patches/kaggle_secrets.py @@ -7,6 +7,7 @@ import os from datetime import datetime, timedelta from enum import Enum, unique +import subprocess from typing import Optional, Tuple from kaggle_web_client import KaggleWebClient from kaggle_web_client import (CredentialError, BackendError) @@ -80,6 +81,28 @@ def get_gcloud_credential(self) -> str: else: raise + def set_gcloud_credentials(self, project=None, account=None): + """Set user credentials attached to the current kernel and optionally the project & account name to the `gcloud` CLI. + + Example usage: + client = UserSecretsClient() + client.set_gcloud_credentials(project="my-gcp-project", account="me@my-org.com") + + !gcloud ai-platform jobs list + """ + creds = self.get_gcloud_credential() + creds_path = self._write_credentials_file(creds) + + subprocess.run(['gcloud', 'config', 'set', 'auth/credential_file_override', creds_path]) + + if project: + os.environ['GOOGLE_CLOUD_PROJECT'] = project + subprocess.run(['gcloud', 'config', 'set', 'project', project]) + + if account: + os.environ['GOOGLE_ACCOUNT'] = account + subprocess.run(['gcloud', 'config', 'set', 'account', account]) + def set_tensorflow_credential(self, credential): """Sets the credential for use by Tensorflow both in the local notebook and to pass to the TPU. @@ -89,11 +112,7 @@ def set_tensorflow_credential(self, credential): # Write to a local JSON credentials file and set # GOOGLE_APPLICATION_CREDENTIALS for tensorflow running in the notebook. - adc_path = os.path.join( - os.environ.get('HOME', '/'), 'gcloud_credential.json') - with open(adc_path, 'w') as f: - f.write(credential) - os.environ['GOOGLE_APPLICATION_CREDENTIALS']=adc_path + self._write_credentials_file(credential) # set the credential for the TPU tensorflow_gcs_config.configure_gcs(credentials=credential) @@ -108,6 +127,14 @@ def get_bigquery_access_token(self) -> Tuple[str, Optional[datetime]]: token, expiry = client.get_bigquery_access_token() """ return self._get_access_token(GcpTarget.BIGQUERY) + + def _write_credentials_file(self, credentials) -> str: + adc_path = os.path.join(os.environ.get('HOME', '/'), 'gcloud_credential.json') + with open(adc_path, 'w') as f: + f.write(credentials) + os.environ['GOOGLE_APPLICATION_CREDENTIALS']=adc_path + + return adc_path def _get_gcs_access_token(self) -> Tuple[str, Optional[datetime]]: return self._get_access_token(GcpTarget.GCS) diff --git a/tests/test_user_secrets.py b/tests/test_user_secrets.py index 6b0c5ffe..46dd6d8b 100644 --- a/tests/test_user_secrets.py +++ b/tests/test_user_secrets.py @@ -1,5 +1,6 @@ import json import os +import subprocess import threading import unittest from http.server import BaseHTTPRequestHandler, HTTPServer @@ -134,6 +135,34 @@ def call_get_secret(): '/requests/GetUserSecretByLabelRequest', {'Label': "__gcloud_sdk_auth__"}, success=False) + def test_set_gcloud_credentials_succeeds(self): + secret = '{"client_id":"gcloud","type":"authorized_user"}' + project = 'foo' + account = 'bar' + + def get_gcloud_config_value(field): + result = subprocess.run(['gcloud', 'config', 'get-value', field], capture_output=True) + result.check_returncode() + return result.stdout.strip().decode('ascii') + + def test_fn(): + client = UserSecretsClient() + client.set_gcloud_credentials(project=project, account=account) + + self.assertEqual(project, os.environ['GOOGLE_CLOUD_PROJECT']) + self.assertEqual(project, get_gcloud_config_value('project')) + + self.assertEqual(account, os.environ['GOOGLE_ACCOUNT']) + self.assertEqual(account, get_gcloud_config_value('account')) + + expected_creds_file = '/tmp/gcloud_credential.json' + self.assertEqual(expected_creds_file, os.environ['GOOGLE_APPLICATION_CREDENTIALS']) + self.assertEqual(expected_creds_file, get_gcloud_config_value('auth/credential_file_override')) + + with open(expected_creds_file, 'r') as f: + self.assertEqual(secret, '\n'.join(f.readlines())) + + self._test_client(test_fn, '/requests/GetUserSecretByLabelRequest', {'Label': "__gcloud_sdk_auth__"}, secret=secret) @mock.patch('kaggle_secrets.datetime') def test_get_access_token_succeeds(self, mock_dt):