Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 32 additions & 5 deletions patches/kaggle_secrets.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import os
from datetime import datetime, timedelta
from enum import Enum, unique
import subprocess
from typing import Optional, Tuple
from kaggle_web_client import KaggleWebClient
from kaggle_web_client import (CredentialError, BackendError)
Expand Down Expand Up @@ -80,6 +81,28 @@ def get_gcloud_credential(self) -> str:
else:
raise

def set_gcloud_credentials(self, project=None, account=None):
"""Set user credentials attached to the current kernel and optionally the project & account name to the `gcloud` CLI.

Example usage:
client = UserSecretsClient()
client.set_gcloud_credentials(project="my-gcp-project", account="[email protected]")

!gcloud ai-platform jobs list
"""
creds = self.get_gcloud_credential()
creds_path = self._write_credentials_file(creds)

subprocess.run(['gcloud', 'config', 'set', 'auth/credential_file_override', creds_path])

if project:
os.environ['GOOGLE_CLOUD_PROJECT'] = project
subprocess.run(['gcloud', 'config', 'set', 'project', project])

if account:
os.environ['GOOGLE_ACCOUNT'] = account
subprocess.run(['gcloud', 'config', 'set', 'account', account])

def set_tensorflow_credential(self, credential):
"""Sets the credential for use by Tensorflow both in the local notebook
and to pass to the TPU.
Expand All @@ -89,11 +112,7 @@ def set_tensorflow_credential(self, credential):

# Write to a local JSON credentials file and set
# GOOGLE_APPLICATION_CREDENTIALS for tensorflow running in the notebook.
adc_path = os.path.join(
os.environ.get('HOME', '/'), 'gcloud_credential.json')
with open(adc_path, 'w') as f:
f.write(credential)
os.environ['GOOGLE_APPLICATION_CREDENTIALS']=adc_path
self._write_credentials_file(credential)

# set the credential for the TPU
tensorflow_gcs_config.configure_gcs(credentials=credential)
Expand All @@ -108,6 +127,14 @@ def get_bigquery_access_token(self) -> Tuple[str, Optional[datetime]]:
token, expiry = client.get_bigquery_access_token()
"""
return self._get_access_token(GcpTarget.BIGQUERY)

def _write_credentials_file(self, credentials) -> str:
adc_path = os.path.join(os.environ.get('HOME', '/'), 'gcloud_credential.json')
with open(adc_path, 'w') as f:
f.write(credentials)
os.environ['GOOGLE_APPLICATION_CREDENTIALS']=adc_path

return adc_path

def _get_gcs_access_token(self) -> Tuple[str, Optional[datetime]]:
return self._get_access_token(GcpTarget.GCS)
Expand Down
29 changes: 29 additions & 0 deletions tests/test_user_secrets.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import json
import os
import subprocess
import threading
import unittest
from http.server import BaseHTTPRequestHandler, HTTPServer
Expand Down Expand Up @@ -134,6 +135,34 @@ def call_get_secret():
'/requests/GetUserSecretByLabelRequest', {'Label': "__gcloud_sdk_auth__"},
success=False)

def test_set_gcloud_credentials_succeeds(self):
secret = '{"client_id":"gcloud","type":"authorized_user"}'
project = 'foo'
account = 'bar'

def get_gcloud_config_value(field):
result = subprocess.run(['gcloud', 'config', 'get-value', field], capture_output=True)
result.check_returncode()
return result.stdout.strip().decode('ascii')

def test_fn():
client = UserSecretsClient()
client.set_gcloud_credentials(project=project, account=account)

self.assertEqual(project, os.environ['GOOGLE_CLOUD_PROJECT'])
self.assertEqual(project, get_gcloud_config_value('project'))

self.assertEqual(account, os.environ['GOOGLE_ACCOUNT'])
self.assertEqual(account, get_gcloud_config_value('account'))

expected_creds_file = '/tmp/gcloud_credential.json'
self.assertEqual(expected_creds_file, os.environ['GOOGLE_APPLICATION_CREDENTIALS'])
self.assertEqual(expected_creds_file, get_gcloud_config_value('auth/credential_file_override'))

with open(expected_creds_file, 'r') as f:
self.assertEqual(secret, '\n'.join(f.readlines()))

self._test_client(test_fn, '/requests/GetUserSecretByLabelRequest', {'Label': "__gcloud_sdk_auth__"}, secret=secret)

@mock.patch('kaggle_secrets.datetime')
def test_get_access_token_succeeds(self, mock_dt):
Expand Down