diff --git a/cirro/api/auth/__init__.py b/cirro/api/auth/__init__.py index b381bf76..73e492ef 100644 --- a/cirro/api/auth/__init__.py +++ b/cirro/api/auth/__init__.py @@ -1,3 +1,5 @@ +from io import StringIO +from typing import Optional from cirro.api.auth.base import AuthInfo from cirro.api.auth.iam import IAMAuth from cirro.api.auth.oauth_client import ClientAuth @@ -12,12 +14,13 @@ from cirro.api.config import AppConfig -def get_auth_info_from_config(app_config: AppConfig): +def get_auth_info_from_config(app_config: AppConfig, auth_io: Optional[StringIO] = None): user_config = app_config.user_config if not user_config or not user_config.auth_method: return ClientAuth(region=app_config.region, client_id=app_config.client_id, - auth_endpoint=app_config.auth_endpoint) + auth_endpoint=app_config.auth_endpoint, + auth_io=auth_io) auth_methods = [ ClientAuth, @@ -33,7 +36,8 @@ def get_auth_info_from_config(app_config: AppConfig): return ClientAuth(region=app_config.region, client_id=app_config.client_id, auth_endpoint=app_config.auth_endpoint, - enable_cache=auth_config.get('enable_cache') == 'True') + enable_cache=auth_config.get('enable_cache') == 'True', + auth_io=auth_io) if matched_auth_method == IAMAuth and auth_config.get('load_current'): return IAMAuth.load_current() diff --git a/cirro/api/auth/oauth_client.py b/cirro/api/auth/oauth_client.py index c956afbf..7dc0c848 100644 --- a/cirro/api/auth/oauth_client.py +++ b/cirro/api/auth/oauth_client.py @@ -1,3 +1,4 @@ +from io import StringIO import json import logging import sys @@ -42,12 +43,15 @@ def _build_token_persistence(location, fallback_to_plaintext=False): return FilePersistence(location) -def _authenticate(client_id: str, auth_endpoint: str): +def _authenticate(client_id: str, auth_endpoint: str, auth_io: Optional[StringIO] = None): params = {'client_id': client_id} resp = requests.post(f'{auth_endpoint}/device-code', params=params) resp.raise_for_status() flow: DeviceTokenResponse = resp.json() - print(flow['message']) + if auth_io is None: + print(flow['message']) + else: + auth_io.write(flow['message']) device_expiry = datetime.fromisoformat(flow['expiry']) params = { @@ -83,7 +87,14 @@ class ClientAuth(AuthInfo): Implements the OAuth device code flow This is the preferred way to authenticate """ - def __init__(self, client_id: str, region: str, auth_endpoint: str, enable_cache=True): + def __init__( + self, + client_id: str, + region: str, + auth_endpoint: str, + enable_cache=True, + auth_io: Optional[StringIO] = None + ): self.client_id = client_id self.region = region self._token_info: Optional[OAuthTokenResponse] = None @@ -107,7 +118,7 @@ def __init__(self, client_id: str, region: str, auth_endpoint: str, enable_cache self._clear_token_info() if not self._token_info: - self._token_info = _authenticate(client_id=client_id, auth_endpoint=auth_endpoint) + self._token_info = _authenticate(client_id=client_id, auth_endpoint=auth_endpoint, auth_io=auth_io) self._save_token_info() self._update_token_metadata() diff --git a/cirro/api/clients/portal.py b/cirro/api/clients/portal.py index a9cb240f..34154671 100644 --- a/cirro/api/clients/portal.py +++ b/cirro/api/clients/portal.py @@ -4,16 +4,17 @@ from cirro.api.clients import ApiClient from cirro.api.config import AppConfig from cirro.api.services import DatasetService, ProcessService, ProjectService, FileService, CommonService +from io import StringIO class DataPortalClient: """ A client for interacting with the Cirro platform """ - def __init__(self, auth_info: Optional[AuthInfo] = None, base_url: str = None): + def __init__(self, auth_info: Optional[AuthInfo] = None, base_url: str = None, auth_io: Optional[StringIO] = None): self._configuration = AppConfig(base_url=base_url) if not auth_info: - auth_info = get_auth_info_from_config(self._configuration) + auth_info = get_auth_info_from_config(self._configuration, auth_io=auth_io) self._api_client = ApiClient(auth_info, data_endpoint=self._configuration.data_endpoint) self._file_service = FileService(self._api_client, self._configuration) diff --git a/pyproject.toml b/pyproject.toml index 7a468e11..e8ce24f3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "cirro" -version = "0.7.2" +version = "0.7.3" description = "CLI tool and SDK for interacting with the Cirro platform" authors = ["Cirro Bio "] license = "MIT"