77import os
88from datetime import datetime , timedelta
99from enum import Enum , unique
10+ import subprocess
1011from typing import Optional , Tuple
1112from kaggle_web_client import KaggleWebClient
1213from kaggle_web_client import (CredentialError , BackendError )
@@ -80,6 +81,28 @@ def get_gcloud_credential(self) -> str:
8081 else :
8182 raise
8283
84+ def set_gcloud_credentials (self , project = None , account = None ):
85+ """Set user credentials attached to the current kernel and optionally the project & account name to the `gcloud` CLI.
86+
87+ Example usage:
88+ client = UserSecretsClient()
89+ client.set_gcloud_credentials(project="my-gcp-project", account="[email protected] ") 90+
91+ !gcloud ai-platform jobs list
92+ """
93+ creds = self .get_gcloud_credential ()
94+ creds_path = self ._write_credentials_file (creds )
95+
96+ subprocess .run (['gcloud' , 'config' , 'set' , 'auth/credential_file_override' , creds_path ])
97+
98+ if project :
99+ os .environ ['GOOGLE_CLOUD_PROJECT' ] = project
100+ subprocess .run (['gcloud' , 'config' , 'set' , 'project' , project ])
101+
102+ if account :
103+ os .environ ['GOOGLE_ACCOUNT' ] = account
104+ subprocess .run (['gcloud' , 'config' , 'set' , 'account' , account ])
105+
83106 def set_tensorflow_credential (self , credential ):
84107 """Sets the credential for use by Tensorflow both in the local notebook
85108 and to pass to the TPU.
@@ -89,11 +112,7 @@ def set_tensorflow_credential(self, credential):
89112
90113 # Write to a local JSON credentials file and set
91114 # GOOGLE_APPLICATION_CREDENTIALS for tensorflow running in the notebook.
92- adc_path = os .path .join (
93- os .environ .get ('HOME' , '/' ), 'gcloud_credential.json' )
94- with open (adc_path , 'w' ) as f :
95- f .write (credential )
96- os .environ ['GOOGLE_APPLICATION_CREDENTIALS' ]= adc_path
115+ self ._write_credentials_file (credential )
97116
98117 # set the credential for the TPU
99118 tensorflow_gcs_config .configure_gcs (credentials = credential )
@@ -108,6 +127,14 @@ def get_bigquery_access_token(self) -> Tuple[str, Optional[datetime]]:
108127 token, expiry = client.get_bigquery_access_token()
109128 """
110129 return self ._get_access_token (GcpTarget .BIGQUERY )
130+
131+ def _write_credentials_file (self , credentials ) -> str :
132+ adc_path = os .path .join (os .environ .get ('HOME' , '/' ), 'gcloud_credential.json' )
133+ with open (adc_path , 'w' ) as f :
134+ f .write (credentials )
135+ os .environ ['GOOGLE_APPLICATION_CREDENTIALS' ]= adc_path
136+
137+ return adc_path
111138
112139 def _get_gcs_access_token (self ) -> Tuple [str , Optional [datetime ]]:
113140 return self ._get_access_token (GcpTarget .GCS )
0 commit comments