diff --git a/cirro/cli/interactive/auth_args.py b/cirro/cli/interactive/auth_args.py index 2359abdf..6072f328 100644 --- a/cirro/cli/interactive/auth_args.py +++ b/cirro/cli/interactive/auth_args.py @@ -1,9 +1,8 @@ import logging -import re from typing import Tuple, Dict from cirro.cli.interactive.utils import ask_yes_no, ask -from cirro.config import list_tenants +from cirro.config import list_tenants, extract_base_url logger = logging.getLogger() @@ -21,8 +20,8 @@ def gather_auth_config() -> Tuple[str, str, Dict, bool]: choices=[tenant['domain'] for tenant in tenant_options], meta_information={tenant['domain']: tenant['displayName'] for tenant in tenant_options} ) - # remove http(s):// if it's there - base_url = re.compile(r'https?://').sub('', base_url).strip('/').strip() + # Fix user-provided base URL, if necessary + base_url = extract_base_url(base_url) auth_method_config = { 'enable_cache': ask_yes_no('Would you like to save your login? (do not use this on shared devices)') diff --git a/cirro/config.py b/cirro/config.py index 332b2116..633305c1 100644 --- a/cirro/config.py +++ b/cirro/config.py @@ -25,6 +25,14 @@ class UserConfig(NamedTuple): enable_additional_checksum: Optional[bool] +def extract_base_url(base_url: str): + # remove http(s):// if it's there + base_url = re.compile(r'https?://').sub('', base_url).strip() + # remove anything after the first slash if it's there + base_url = base_url.split('/')[0] + return base_url + + def list_tenants() -> List[Tenant]: resp = requests.get(f'https://nexus.{Constants.default_base_url}/info') resp.raise_for_status() diff --git a/tests/test_config_load.py b/tests/test_config_load.py index 52aff3ff..9ad6e6a0 100644 --- a/tests/test_config_load.py +++ b/tests/test_config_load.py @@ -1,14 +1,27 @@ import unittest -from cirro.config import AppConfig, list_tenants +from cirro.config import AppConfig, list_tenants, extract_base_url + +TEST_BASE_URL = "app.cirro.bio" class TestConfigLoad(unittest.TestCase): def test_config_load(self): - config = AppConfig(base_url="app.cirro.bio") + config = AppConfig(base_url=TEST_BASE_URL) self.assertIsNotNone(config.client_id) self.assertIsNotNone(config.auth_endpoint) def test_list_tenants(self): tenants = list_tenants() self.assertGreater(len(tenants), 1) + + def test_extract_base(self): + test_cases = [ + f"https://{TEST_BASE_URL}", + TEST_BASE_URL, + f"https://{TEST_BASE_URL}/projects", + f"{TEST_BASE_URL}/asd/", + ] + for test_case in test_cases: + with self.subTest(test_case): + self.assertEqual(TEST_BASE_URL, extract_base_url(test_case))