Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions cirro/api/auth/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from io import StringIO
from typing import Optional
from cirro.api.auth.base import AuthInfo
from cirro.api.auth.iam import IAMAuth
from cirro.api.auth.oauth_client import ClientAuth
Expand All @@ -12,12 +14,13 @@
from cirro.api.config import AppConfig


def get_auth_info_from_config(app_config: AppConfig):
def get_auth_info_from_config(app_config: AppConfig, auth_io: Optional[StringIO] = None):
user_config = app_config.user_config
if not user_config or not user_config.auth_method:
return ClientAuth(region=app_config.region,
client_id=app_config.client_id,
auth_endpoint=app_config.auth_endpoint)
auth_endpoint=app_config.auth_endpoint,
auth_io=auth_io)

auth_methods = [
ClientAuth,
Expand All @@ -33,7 +36,8 @@ def get_auth_info_from_config(app_config: AppConfig):
return ClientAuth(region=app_config.region,
client_id=app_config.client_id,
auth_endpoint=app_config.auth_endpoint,
enable_cache=auth_config.get('enable_cache') == 'True')
enable_cache=auth_config.get('enable_cache') == 'True',
auth_io=auth_io)

if matched_auth_method == IAMAuth and auth_config.get('load_current'):
return IAMAuth.load_current()
Expand Down
19 changes: 15 additions & 4 deletions cirro/api/auth/oauth_client.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from io import StringIO
import json
import logging
import sys
Expand Down Expand Up @@ -42,12 +43,15 @@ def _build_token_persistence(location, fallback_to_plaintext=False):
return FilePersistence(location)


def _authenticate(client_id: str, auth_endpoint: str):
def _authenticate(client_id: str, auth_endpoint: str, auth_io: Optional[StringIO] = None):
params = {'client_id': client_id}
resp = requests.post(f'{auth_endpoint}/device-code', params=params)
resp.raise_for_status()
flow: DeviceTokenResponse = resp.json()
print(flow['message'])
if auth_io is None:
print(flow['message'])
else:
auth_io.write(flow['message'])
device_expiry = datetime.fromisoformat(flow['expiry'])

params = {
Expand Down Expand Up @@ -83,7 +87,14 @@ class ClientAuth(AuthInfo):
Implements the OAuth device code flow
This is the preferred way to authenticate
"""
def __init__(self, client_id: str, region: str, auth_endpoint: str, enable_cache=True):
def __init__(
self,
client_id: str,
region: str,
auth_endpoint: str,
enable_cache=True,
auth_io: Optional[StringIO] = None
):
self.client_id = client_id
self.region = region
self._token_info: Optional[OAuthTokenResponse] = None
Expand All @@ -107,7 +118,7 @@ def __init__(self, client_id: str, region: str, auth_endpoint: str, enable_cache
self._clear_token_info()

if not self._token_info:
self._token_info = _authenticate(client_id=client_id, auth_endpoint=auth_endpoint)
self._token_info = _authenticate(client_id=client_id, auth_endpoint=auth_endpoint, auth_io=auth_io)

self._save_token_info()
self._update_token_metadata()
Expand Down
5 changes: 3 additions & 2 deletions cirro/api/clients/portal.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,17 @@
from cirro.api.clients import ApiClient
from cirro.api.config import AppConfig
from cirro.api.services import DatasetService, ProcessService, ProjectService, FileService, CommonService
from io import StringIO


class DataPortalClient:
"""
A client for interacting with the Cirro platform
"""
def __init__(self, auth_info: Optional[AuthInfo] = None, base_url: str = None):
def __init__(self, auth_info: Optional[AuthInfo] = None, base_url: str = None, auth_io: Optional[StringIO] = None):
self._configuration = AppConfig(base_url=base_url)
if not auth_info:
auth_info = get_auth_info_from_config(self._configuration)
auth_info = get_auth_info_from_config(self._configuration, auth_io=auth_io)

self._api_client = ApiClient(auth_info, data_endpoint=self._configuration.data_endpoint)
self._file_service = FileService(self._api_client, self._configuration)
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "cirro"
version = "0.7.2"
version = "0.7.3"
description = "CLI tool and SDK for interacting with the Cirro platform"
authors = ["Cirro Bio <[email protected]>"]
license = "MIT"
Expand Down