From 1246b28e9e36dd2e5c615b79b36629e99b30719e Mon Sep 17 00:00:00 2001 From: Elisha Ben Yosef Date: Wed, 6 Nov 2024 14:27:04 +0200 Subject: [PATCH 1/8] Updated code to support ssh into sagemaker space apps - JupyterLab and CodeEditor 1. Added Dockerfile to build custom image 2. Updated script to ssh into sagemaker space apps --- sagemaker_ssh_helper/ide.py | 24 +++++++++++- sagemaker_ssh_helper/manager.py | 13 +++++++ sagemaker_ssh_helper/sm-connect-ssh-proxy | 6 +-- sagemaker_ssh_helper/sm-helper-functions | 11 ++++-- sagemaker_ssh_helper/sm-init-ssm | 6 ++- sagemaker_ssh_helper/sm-local-ssh-ide | 37 ++++++++++++++++--- sagemaker_ssh_helper/sm-ssh-ide | 33 +++++++++++------ .../Dockerfile.codeeditor.internet_free | 35 ++++++++++++++++++ 8 files changed, 139 insertions(+), 26 deletions(-) create mode 100644 tests/byoi_studio/Dockerfile.codeeditor.internet_free diff --git a/sagemaker_ssh_helper/ide.py b/sagemaker_ssh_helper/ide.py index 6f05a8e..93b7f6f 100644 --- a/sagemaker_ssh_helper/ide.py +++ b/sagemaker_ssh_helper/ide.py @@ -46,9 +46,10 @@ def __init__(self, arn, version_arn) -> None: class SSHIDE: logger = logging.getLogger('sagemaker-ssh-helper:SSHIDE') - def __init__(self, domain_id: str, user: str, region_name: str = None): + def __init__(self, domain_id: str, user: str, region_name: str = None, space: str = None): self.user = user self.domain_id = domain_id + self.space = space self.current_region = region_name or boto3.session.Session().region_name self.client = boto3.client('sagemaker', region_name=self.current_region) self.ssh_log = SSHLog(region_name=self.current_region) @@ -202,6 +203,9 @@ def resolve_sagemaker_kernel_image_arn(self, image_name): def print_kernel_instance_id(self, app_name, timeout_in_sec, index: int = 0): print(self.get_kernel_instance_id(app_name, timeout_in_sec, index)) + def print_space_instance_id(self, app_name, timeout_in_sec, index: int = 0): + print(self.get_space_instance_ids(app_name, timeout_in_sec)[index]) + def get_kernel_instance_id(self, app_name, timeout_in_sec, index: int = 0, not_earlier_than_timestamp: int = 0): ids = self.get_kernel_instance_ids(app_name, timeout_in_sec, not_earlier_than_timestamp) @@ -228,6 +232,24 @@ def get_kernel_instance_ids(self, app_name: str, timeout_in_sec: int, not_earlie result = SSMManager().get_studio_kgw_instance_ids(app_name, timeout_in_sec, not_earlier_than_timestamp) return result + def get_space_instance_ids(self, app_name, timeout_in_sec): + self.logger.info("Resolving IDE instance IDs through SSM tags") + self.log_urls(app_name) + if self.domain_id and self.space: + result = SSMManager().get_studio_space_app_instance_ids(self.domain_id, self.space, app_name, timeout_in_sec) + elif self.space: + self.logger.warning(f"Domain ID is not set. Will attempt to connect to the latest " + f"active kernel gateway with the name {app_name} in the region {self.current_region} " + f"for space {self.space}") + result = SSMManager().get_studio_space_app_instance_ids("", self.space, app_name, + timeout_in_sec) + else: + self.logger.warning(f"Domain ID or space name are not set. Will attempt to connect to the latest " + f"active kernel gateway with the name {app_name} in the region {self.current_region}") + result = SSMManager().get_studio_app_instance_ids(app_name, timeout_in_sec) + return result + + def log_urls(self, app_name): self.logger.info(f"Remote logs are at {self.get_cloudwatch_url(app_name)}") if self.domain_id and self.user: diff --git a/sagemaker_ssh_helper/manager.py b/sagemaker_ssh_helper/manager.py index 5ea2aaa..9a70764 100644 --- a/sagemaker_ssh_helper/manager.py +++ b/sagemaker_ssh_helper/manager.py @@ -113,6 +113,15 @@ def get_transformer_instance_ids(self, transform_job_name, timeout_in_sec=0): self.logger.info(f"Querying SSM instance IDs for transform job {transform_job_name}") return self.get_instance_ids('transform-job', transform_job_name, timeout_in_sec) + def get_studio_space_app_instance_ids(self, domain_id, space_name, app_name, timeout_in_sec=0): + self.logger.info(f"Querying SSM instance IDs for SageMaker Studio space {app_name}") + if not domain_id: + arn_filter = f":app/.*/{space_name}/" + else: + arn_filter = f":app/{domain_id}/{space_name}/" + return self.get_instance_ids('app', f"{app_name}", timeout_in_sec, + arn_filter_regex=arn_filter) + def get_studio_user_kgw_instance_ids(self, domain_id, user_profile_name, kgw_name, timeout_in_sec=0, not_earlier_than_timestamp: int = 0): self.logger.info(f"Querying SSM instance IDs for SageMaker Studio kernel gateway: '{kgw_name}'") @@ -129,6 +138,10 @@ def get_studio_kgw_instance_ids(self, kgw_name, timeout_in_sec=0, not_earlier_th return self.get_instance_ids('app', f"{kgw_name}", timeout_in_sec, not_earlier_than_timestamp) + def get_studio_app_instance_ids(self, app_name, timeout_in_sec=0): + self.logger.info(f"Querying SSM instance IDs for SageMaker Studio space {app_name}") + return self.get_instance_ids('app', f"{app_name}", timeout_in_sec) + def get_notebook_instance_ids(self, instance_name, timeout_in_sec=0): self.logger.info(f"Querying SSM instance IDs for SageMaker notebook instance {instance_name}") return self.get_instance_ids('notebook-instance', f"{instance_name}", diff --git a/sagemaker_ssh_helper/sm-connect-ssh-proxy b/sagemaker_ssh_helper/sm-connect-ssh-proxy index f4cf70f..70742e9 100644 --- a/sagemaker_ssh_helper/sm-connect-ssh-proxy +++ b/sagemaker_ssh_helper/sm-connect-ssh-proxy @@ -94,7 +94,7 @@ send_command=$(aws ssm send-command \ 'cat /etc/ssh/authorized_keys.d/* > /etc/ssh/authorized_keys', 'ls -la /etc/ssh/authorized_keys' ]" \ - --no-cli-pager --no-paginate \ + --no-paginate \ --output json) json_value_regexp='s/^[^"]*".*": \"\(.*\)\"[^"]*/\1/' @@ -114,7 +114,7 @@ for i in $(seq 1 15); do command_output=$(aws ssm get-command-invocation \ --instance-id "${INSTANCE_ID}" \ --command-id "${command_id}" \ - --no-cli-pager --no-paginate \ + --no-paginate \ --output json) command_output=$(echo "$command_output" | $(_python) -m json.tool) command_status=$(echo "$command_output" | grep '"Status":' | sed -e "$json_value_regexp") @@ -166,7 +166,7 @@ proxy_command="aws ssm start-session\ --parameters portNumber=%p" # shellcheck disable=SC2086 -ssh -4 -o User=root -o IdentityFile="${SSH_KEY}" -o IdentitiesOnly=yes \ +ssh -4 -o User=sagemaker-user -o IdentityFile="${SSH_KEY}" -o IdentitiesOnly=yes \ -o ProxyCommand="$proxy_command" \ -o ConnectTimeout=90 \ -o ServerAliveInterval=15 -o ServerAliveCountMax=3 \ diff --git a/sagemaker_ssh_helper/sm-helper-functions b/sagemaker_ssh_helper/sm-helper-functions index 704b22d..e8f8090 100644 --- a/sagemaker_ssh_helper/sm-helper-functions +++ b/sagemaker_ssh_helper/sm-helper-functions @@ -194,9 +194,14 @@ function _print_sm_domain_id() { # shellcheck disable=SC2001 function _print_sm_user_profile_name() { - # FIXME: Check for "SpaceName" - spaces are not supported yet sm_resource_metadata_json=$(tr -d "\n" < /opt/ml/metadata/resource-metadata.json) - echo -n "$sm_resource_metadata_json" | sed -e 's/^.*"UserProfileName":\"\([^"]*\)\".*$/\1/' + echo -n "$sm_resource_metadata_json" | jq -r '.UserProfileName' +} + +# shellcheck disable=SC2001 +function _print_sm_space_name() { + sm_resource_metadata_json=$(tr -d "\n" < /opt/ml/metadata/resource-metadata.json) + echo -n "$sm_resource_metadata_json" | jq -r '.SpaceName' } function _print_sm_studio_python() { @@ -242,7 +247,7 @@ function _start_sshd() { if _is_centos; then /usr/sbin/sshd else - service ssh start || (echo "ERROR: Failed to start sshd service" && exit 255) + sudo service ssh start || (echo "ERROR: Failed to start sshd service" && exit 255) fi } diff --git a/sagemaker_ssh_helper/sm-init-ssm b/sagemaker_ssh_helper/sm-init-ssm index 9d6f4a0..6edc15c 100644 --- a/sagemaker_ssh_helper/sm-init-ssm +++ b/sagemaker_ssh_helper/sm-init-ssm @@ -59,4 +59,8 @@ response=$(aws ssm create-activation \ acode=$(echo $response | jq --raw-output '.ActivationCode') aid=$(echo $response | jq --raw-output '.ActivationId') -echo Yes | amazon-ssm-agent -register -id "$aid" -code "$acode" -region "$CURRENT_REGION" +if [[ -n $(_print_sm_user_profile_name) && $(_print_sm_user_profile_name) != "null" ]]; then + echo Yes | amazon-ssm-agent -register -id "$aid" -code "$acode" -region "$CURRENT_REGION" +else + echo Yes | sudo amazon-ssm-agent -register -id "$aid" -code "$acode" -region "$CURRENT_REGION" +fi diff --git a/sagemaker_ssh_helper/sm-local-ssh-ide b/sagemaker_ssh_helper/sm-local-ssh-ide index b3e6585..a877ef2 100644 --- a/sagemaker_ssh_helper/sm-local-ssh-ide +++ b/sagemaker_ssh_helper/sm-local-ssh-ide @@ -39,6 +39,7 @@ echo "sm-local-ssh-ide: Starting in $dir" DOMAIN_ID="" USER_PROFILE_NAME="" +SPACE_NAME="" if [[ "$1" == "--domain-id" ]]; then DOMAIN_ID=$2 @@ -68,6 +69,12 @@ if [[ "$COMMAND" != "set-user-profile-name" && "$COMMAND" != "set-domain-id" \ echo "sm-local-ssh-ide: WARNING: SageMaker Studio user profile name is not set."\ "Run 'sm-local-ssh-ide set-user-profile-name' to override." fi + if [ -f ~/.sm-studio-space-name ]; then + SPACE_NAME="$(cat ~/.sm-studio-space-name)" + else + echo "sm-local-ssh-ide: WARNING: SageMaker Studio space name is not set."\ + "Run 'sm-local-ssh-ide set-space-name' to override." + fi fi if [[ "$COMMAND" == "proxy-host" ]]; then @@ -87,12 +94,21 @@ elif [[ "$COMMAND" == "connect-app" ]]; then OPTIONS="$3" # shellcheck disable=SC2091 - INSTANCE_ID=$($(_python) < ~/.sm-studio-user-profile-name +elif [[ "$COMMAND" == "set-space-name" ]]; then + SPACE_NAME="$(echo "$2" | tr '[:upper:]' '[:lower:]')" + if [[ "$SPACE_NAME" == "" ]]; then + echo "sm-local-ssh-ide: ERROR: argument is expected" + exit 1 + fi + echo "sm-local-ssh-ide: Saving SageMaker Studio user profile name into ~/.sm-studio-space-name" + echo "$SPACE_NAME" > ~/.sm-studio-space-name + elif [[ "$COMMAND" == "run-command" ]]; then shift diff --git a/sagemaker_ssh_helper/sm-ssh-ide b/sagemaker_ssh_helper/sm-ssh-ide index 9b080c1..a45ac76 100644 --- a/sagemaker_ssh_helper/sm-ssh-ide +++ b/sagemaker_ssh_helper/sm-ssh-ide @@ -34,10 +34,10 @@ if [[ "$1" == "configure" ]]; then cat >/etc/profile.d/sm-ssh-ide.sh < Date: Wed, 6 Nov 2024 22:26:59 +0200 Subject: [PATCH 2/8] Add support for Sagemaker space --- sagemaker_ssh_helper/ide.py | 50 +++++++++++-------- sagemaker_ssh_helper/interactive_sagemaker.py | 24 ++++----- sagemaker_ssh_helper/sm-local-ssh-ide | 26 ++++++---- sagemaker_ssh_helper/sm_ssh.py | 30 ++++++++--- tests/test_ssm_manager.py | 36 ++++--------- 5 files changed, 91 insertions(+), 75 deletions(-) diff --git a/sagemaker_ssh_helper/ide.py b/sagemaker_ssh_helper/ide.py index 93b7f6f..7abd073 100644 --- a/sagemaker_ssh_helper/ide.py +++ b/sagemaker_ssh_helper/ide.py @@ -46,7 +46,7 @@ def __init__(self, arn, version_arn) -> None: class SSHIDE: logger = logging.getLogger('sagemaker-ssh-helper:SSHIDE') - def __init__(self, domain_id: str, user: str, region_name: str = None, space: str = None): + def __init__(self, domain_id: str, user: str = None, region_name: str = None, space: str = None): self.user = user self.domain_id = domain_id self.space = space @@ -109,13 +109,17 @@ def get_app_status(self, app_name: str, app_type: str = 'KernelGateway') -> IDEA :return: None | 'InService' | 'Deleted' | 'Deleting' | 'Failed' | 'Pending' """ response = None + + describe_app_params = { + "DomainId": self.domain_id, + "AppType": app_type, + "AppName": app_name, + } + + describe_app_params.update({"UserProfileName": self.user} if self.user else {"SpaceName": self.space}) + try: - response = self.client.describe_app( - DomainId=self.domain_id, - AppType=app_type, - UserProfileName=self.user, - AppName=app_name, - ) + response = self.client.describe_app(**describe_app_params) except ClientError as e: error_code = e.response.get("Error", {}).get("Code") if error_code == 'ResourceNotFound': @@ -138,12 +142,15 @@ def delete_app(self, app_name, app_type, wait: bool = True): self.logger.info(f"Deleting app {app_name}") try: - _ = self.client.delete_app( - DomainId=self.domain_id, - AppType=app_type, - UserProfileName=self.user, - AppName=app_name, - ) + delete_app_params = { + "DomainId": self.domain_id, + "AppType": app_type, + "AppName": app_name, + } + + delete_app_params.update({"UserProfileName": self.user} if self.user else {"SpaceName": self.space}) + + _ = self.client.delete_app(**delete_app_params) except ClientError as e: # probably, already deleted code = e.response.get("Error", {}).get("Code") @@ -174,13 +181,16 @@ def create_app(self, app_name, app_type, instance_type, image_arn, if lifecycle_arn: resource_spec['LifecycleConfigArn'] = lifecycle_arn - _ = self.client.create_app( - DomainId=self.domain_id, - AppType=app_type, - AppName=app_name, - UserProfileName=self.user, - ResourceSpec=resource_spec, - ) + create_app_params = { + "DomainId": self.domain_id, + "AppType": app_type, + "AppName": app_name, + "ResourceSpec": resource_spec, + } + + create_app_params.update({"UserProfileName": self.user} if self.user else {"SpaceName": self.space}) + + _ = self.client.create_app(**create_app_params) status = self.get_app_status(app_name) while status.is_pending(): self.logger.info(f"Waiting for the InService status. Current status: {status}") diff --git a/sagemaker_ssh_helper/interactive_sagemaker.py b/sagemaker_ssh_helper/interactive_sagemaker.py index 73a48d2..21fd695 100644 --- a/sagemaker_ssh_helper/interactive_sagemaker.py +++ b/sagemaker_ssh_helper/interactive_sagemaker.py @@ -31,15 +31,16 @@ def set_ping_status(self, ping_status): class SageMakerStudioApp(SageMakerCoreApp): - def __init__(self, domain_id: str, user_profile_name: str, app_name: str, app_type: str, - app_status: IDEAppStatus) -> None: + def __init__(self, domain_id: str, app_name: str, app_type: str, app_status: IDEAppStatus, user_profile_name: str = None, + space_name: str = None) -> None: super().__init__() self.app_status = app_status self.app_type = app_type self.app_name = app_name self.user_profile_name = user_profile_name + self.space_name = space_name self.domain_id = domain_id - self.resource_type = "ide" + self.resource_type = "ide" if user_profile_name else "ide-space" def __str__(self) -> str: return "{0:<16} {1:<18} {2:<12} {5}.{4}.{3}.{6}".format( @@ -47,7 +48,7 @@ def __str__(self) -> str: self.app_type, str(self.app_status), self.domain_id, - self.user_profile_name, + self.user_profile_name if self.user_profile_name else self.space_name, self.app_name, SageMakerSecureShellHelper.type_to_fqdn(self.resource_type) ) @@ -160,17 +161,15 @@ def list_ide_apps(self) -> List[SageMakerStudioApp]: app_name = app_dict['AppName'] app_type = app_dict['AppType'] if 'SpaceName' in app_dict: - logging.info("Don't support spaces: skipping app %s of type %s" % (app_name, app_type)) - pass + space_name = app_dict['SpaceName'] + logging.info("Found app %s of type %s for space %s" % (app_name, app_type, space_name)) + app_status = SSHIDE(domain_id, None, self.region, space_name).get_app_status(app_name, app_type) + result.append(SageMakerStudioApp(domain_id, app_dict['AppName'], app_dict['AppType'], app_status, space_name=space_name)) elif app_type in ['JupyterServer', 'KernelGateway']: user_profile_name = app_dict['UserProfileName'] logging.info("Found app %s of type %s for user %s" % (app_name, app_type, user_profile_name)) app_status = SSHIDE(domain_id, user_profile_name, self.region).get_app_status(app_name, app_type) - result.append(SageMakerStudioApp( - domain_id, user_profile_name, - app_dict['AppName'], app_dict['AppType'], - app_status - )) + result.append(SageMakerStudioApp(domain_id, app_dict['AppName'], app_dict['AppType'], app_status, user_profile_name=user_profile_name)) else: logging.info("Unsupported app type %s" % app_type) pass # We don't support other types like 'DetailedProfiler' @@ -342,7 +341,8 @@ def _find_latest_app_instance_id(managed_instances: Dict[str, Dict[str, str]], s arn = tags['SSHResourceArn'] if 'SSHResourceArn' in tags else '' timestamp = int(tags['SSHTimestamp']) if 'SSHTimestamp' in tags else 0 if (':app/' in arn and arn.endswith(f"/{sagemaker_app.app_name}") - and f"/{sagemaker_app.user_profile_name}/" in arn + and (f"/{sagemaker_app.user_profile_name}/" in arn or + f"/{sagemaker_app.space_name}/" in arn) and f"/{sagemaker_app.domain_id}/" in arn and timestamp > max_timestamp): result = managed_instance_id diff --git a/sagemaker_ssh_helper/sm-local-ssh-ide b/sagemaker_ssh_helper/sm-local-ssh-ide index a877ef2..b8c084b 100644 --- a/sagemaker_ssh_helper/sm-local-ssh-ide +++ b/sagemaker_ssh_helper/sm-local-ssh-ide @@ -53,6 +53,12 @@ if [[ "$1" == "--user-profile-name" ]]; then shift fi +if [[ "$1" == "--space-name" ]]; then + SPACE_NAME=$2 + shift + shift +fi + COMMAND=$1 if [[ "$COMMAND" != "set-user-profile-name" && "$COMMAND" != "set-domain-id" \ @@ -95,18 +101,18 @@ elif [[ "$COMMAND" == "connect-app" ]]; then # shellcheck disable=SC2091 if [ -z "$SPACE_NAME" ]; then - INSTANCE_ID=$(python < str: if fqdn.endswith(".studio.sagemaker") or fqdn == "studio.sagemaker": return "ide" + elif fqdn.endswith(".space.sagemaker") or fqdn == "space.sagemaker": + return "ide-space" elif fqdn.endswith(".notebook.sagemaker") or fqdn == "notebook.sagemaker": return "notebook" elif fqdn.endswith(".training.sagemaker") or fqdn == "training.sagemaker": @@ -38,6 +40,8 @@ def fqdn_to_type(fqdn: str) -> str: def type_to_fqdn(cls, resource_type): if resource_type == "ide": return "studio.sagemaker" + elif resource_type == "ide-space": + return "space.sagemaker" elif resource_type == "notebook": return "notebook.sagemaker" elif resource_type == "training": @@ -76,20 +80,30 @@ def fqdn_to_studio_user_name(fqdn: str) -> str: else: return '' + fqdn_to_studio_space_name = fqdn_to_studio_user_name + @classmethod def _get_arguments(cls, fqdn, resource, command): - domain_id = "" user_profile_name = "" - if resource == "ide": + space_name = "" + + if resource.startswith("ide"): domain_id = SageMakerSecureShellHelper.fqdn_to_studio_domain_id(fqdn) - user_profile_name = SageMakerSecureShellHelper.fqdn_to_studio_user_name(fqdn) - if domain_id and user_profile_name: - arguments = ["bash", f"sm-local-ssh-{resource}", - "--domain-id", domain_id, "--user-profile-name", user_profile_name] + if resource.endswith("space"): + space_name = SageMakerSecureShellHelper.fqdn_to_studio_space_name(fqdn) + else: + user_profile_name = SageMakerSecureShellHelper.fqdn_to_studio_user_name(fqdn) + + arguments = ["bash", "sm-local-ssh-ide", "--domain-id", domain_id] + + if domain_id and user_profile_name: + arguments.extend(["--user-profile-name", user_profile_name]) + elif domain_id and space_name: + arguments.extend(["--space-name", space_name]) else: arguments = ["bash", f"sm-local-ssh-{resource}"] - if resource == "ide": + if resource.startswith("ide"): arguments.append(cls._sm_ssh_command_to_local_ssh_command(command, "app")) elif resource == "notebook": arguments.append(cls._sm_ssh_command_to_local_ssh_command(command, "notebook")) @@ -137,7 +151,7 @@ def list(self, fqdn): for resource in self.resources: if resource_type == resource or resource_type == "all": # if-then-else branch for every resource type: - if resource == "ide": + if resource.startswith("ide"): domain_id = SageMakerSecureShellHelper.fqdn_to_studio_domain_id(fqdn) user_profile_name = SageMakerSecureShellHelper.fqdn_to_studio_user_name(fqdn) interactive_sagemaker.print_studio_ide_apps_for_user_and_domain(domain_id, user_profile_name) diff --git a/tests/test_ssm_manager.py b/tests/test_ssm_manager.py index 0766925..a4cb2fc 100644 --- a/tests/test_ssm_manager.py +++ b/tests/test_ssm_manager.py @@ -310,31 +310,17 @@ def test_can_list_ssh_and_non_ssh_instances(): sagemaker = SageMaker('eu-west-1') sagemaker.list_ide_apps = Mock(return_value=[ - SageMakerStudioApp( - "d-0123456789bc", "janedoe", "default", "JupyterServer", IDEAppStatus("InService") - ), - SageMakerStudioApp( - "d-0123456789bc", "janedoe", "ssh-test-kgw", "KernelGateway", IDEAppStatus("InService") - ), - SageMakerStudioApp( - "d-0123456789bc", "janedoe", "data-science-m5-no-ssh", "KernelGateway", IDEAppStatus("InService") - ), - SageMakerStudioApp( - "d-0123456789bc", "janedoe", "data-science-g4-no-ssh", "KernelGateway", IDEAppStatus("Offline") - ), - SageMakerStudioApp( - "d-0123456789bc", "terry", "sagemaker-data-science-ml-m5-large-1234567890abcdef0", "KernelGateway", IDEAppStatus("Offline") - ), - SageMakerStudioApp( - "d-0123456789bc", "terry", "data-science-m5-no-ssh", "KernelGateway", IDEAppStatus("Offline") - ), - - SageMakerStudioApp( - "d-0123456789ab", "terry", "sagemaker-data-science-ml-m5-large-1234567890abcdef0", "KernelGateway", IDEAppStatus("Offline") - ), - SageMakerStudioApp( - "d-0123456789ab", "terry", "data-science-m5-no-ssh", "KernelGateway", IDEAppStatus("Offline") - ), + SageMakerStudioApp("d-0123456789bc", "default", "JupyterServer", IDEAppStatus("InService"), "janedoe"), + SageMakerStudioApp("d-0123456789bc", "ssh-test-kgw", "KernelGateway", IDEAppStatus("InService"), "janedoe"), + SageMakerStudioApp("d-0123456789bc", "data-science-m5-no-ssh", "KernelGateway", IDEAppStatus("InService"), "janedoe"), + SageMakerStudioApp("d-0123456789bc", "data-science-g4-no-ssh", "KernelGateway", IDEAppStatus("Offline"), "janedoe"), + SageMakerStudioApp("d-0123456789bc", "sagemaker-data-science-ml-m5-large-1234567890abcdef0", "KernelGateway", + IDEAppStatus("Offline"), "terry"), + SageMakerStudioApp("d-0123456789bc", "data-science-m5-no-ssh", "KernelGateway", IDEAppStatus("Offline"), "terry"), + + SageMakerStudioApp("d-0123456789ab", "sagemaker-data-science-ml-m5-large-1234567890abcdef0", "KernelGateway", + IDEAppStatus("Offline"), "terry"), + SageMakerStudioApp("d-0123456789ab", "data-science-m5-no-ssh", "KernelGateway", IDEAppStatus("Offline"), "terry"), # LocalApp("janedoe", "AIDACKCEVSQ6C2EXAMPLE:janedoe@SSO", "macOS 13.5.1"), # LocalApp("terry", "AIDACKCEVSQ6C2EXAMPLE:terry@SSO", "Windows 10 Pro"), ]) From 2664427e391f73875d0a45e43fd3d536977468dd Mon Sep 17 00:00:00 2001 From: Elisha Ben Yosef Date: Sun, 10 Nov 2024 10:12:42 +0200 Subject: [PATCH 3/8] Support both user profile and space seamlessly --- sagemaker_ssh_helper/ide.py | 93 ++++++++----------- sagemaker_ssh_helper/interactive_sagemaker.py | 36 ++++--- sagemaker_ssh_helper/manager.py | 19 +--- sagemaker_ssh_helper/sm-local-ssh-ide | 89 +++++++----------- sagemaker_ssh_helper/sm_ssh.py | 36 +++---- sagemaker_ssh_helper/wrapper.py | 2 +- tests/test_cli.py | 8 +- tests/test_ide.py | 4 +- tests/test_ssm_manager.py | 16 ++-- 9 files changed, 121 insertions(+), 182 deletions(-) diff --git a/sagemaker_ssh_helper/ide.py b/sagemaker_ssh_helper/ide.py index 7abd073..0cafa87 100644 --- a/sagemaker_ssh_helper/ide.py +++ b/sagemaker_ssh_helper/ide.py @@ -46,13 +46,13 @@ def __init__(self, arn, version_arn) -> None: class SSHIDE: logger = logging.getLogger('sagemaker-ssh-helper:SSHIDE') - def __init__(self, domain_id: str, user: str = None, region_name: str = None, space: str = None): - self.user = user + def __init__(self, domain_id: str, user_or_space: str = None, region_name: str = None, is_user_profile: bool = True): + self.user_or_space = user_or_space self.domain_id = domain_id - self.space = space self.current_region = region_name or boto3.session.Session().region_name self.client = boto3.client('sagemaker', region_name=self.current_region) self.ssh_log = SSHLog(region_name=self.current_region) + self.is_user_profile = is_user_profile def create_ssh_kernel_app(self, app_name: str, image_name_or_arn='sagemaker-datascience-38', @@ -110,16 +110,16 @@ def get_app_status(self, app_name: str, app_type: str = 'KernelGateway') -> IDEA """ response = None - describe_app_params = { + describe_app_request_params = { "DomainId": self.domain_id, "AppType": app_type, "AppName": app_name, } - describe_app_params.update({"UserProfileName": self.user} if self.user else {"SpaceName": self.space}) + describe_app_request_params.update({"UserProfileName": self.user_or_space} if self.is_user_profile else {"SpaceName": self.user_or_space}) try: - response = self.client.describe_app(**describe_app_params) + response = self.client.describe_app(**describe_app_request_params) except ClientError as e: error_code = e.response.get("Error", {}).get("Code") if error_code == 'ResourceNotFound': @@ -142,15 +142,15 @@ def delete_app(self, app_name, app_type, wait: bool = True): self.logger.info(f"Deleting app {app_name}") try: - delete_app_params = { + delete_app_request_params = { "DomainId": self.domain_id, "AppType": app_type, "AppName": app_name, } - delete_app_params.update({"UserProfileName": self.user} if self.user else {"SpaceName": self.space}) + delete_app_request_params.update({"UserProfileName": self.user_or_space} if self.is_user_profile else {"SpaceName": self.user_or_space}) - _ = self.client.delete_app(**delete_app_params) + _ = self.client.delete_app(**delete_app_request_params) except ClientError as e: # probably, already deleted code = e.response.get("Error", {}).get("Code") @@ -181,16 +181,16 @@ def create_app(self, app_name, app_type, instance_type, image_arn, if lifecycle_arn: resource_spec['LifecycleConfigArn'] = lifecycle_arn - create_app_params = { + create_app_request_params = { "DomainId": self.domain_id, "AppType": app_type, "AppName": app_name, "ResourceSpec": resource_spec, } - create_app_params.update({"UserProfileName": self.user} if self.user else {"SpaceName": self.space}) + create_app_request_params.update({"UserProfileName": self.user_or_space} if self.is_user_profile else {"SpaceName": self.user_or_space}) - _ = self.client.create_app(**create_app_params) + _ = self.client.create_app(**create_app_request_params) status = self.get_app_status(app_name) while status.is_pending(): self.logger.info(f"Waiting for the InService status. Current status: {status}") @@ -210,66 +210,47 @@ def resolve_sagemaker_kernel_image_arn(self, image_name): sagemaker_account_id = "470317259841" # eu-west-1, TODO: check all images return f"arn:aws:sagemaker:{self.current_region}:{sagemaker_account_id}:image/{image_name}" - def print_kernel_instance_id(self, app_name, timeout_in_sec, index: int = 0): - print(self.get_kernel_instance_id(app_name, timeout_in_sec, index)) + def print_instance_id(self, app_name, timeout_in_sec, index: int = 0): + print(self.get_instance_id(app_name, timeout_in_sec, index)) - def print_space_instance_id(self, app_name, timeout_in_sec, index: int = 0): - print(self.get_space_instance_ids(app_name, timeout_in_sec)[index]) - - def get_kernel_instance_id(self, app_name, timeout_in_sec, index: int = 0, - not_earlier_than_timestamp: int = 0): - ids = self.get_kernel_instance_ids(app_name, timeout_in_sec, not_earlier_than_timestamp) + def get_instance_id(self, app_name, timeout_in_sec, index: int = 0, + not_earlier_than_timestamp: int = 0): + ids = self.get_instance_ids(app_name, timeout_in_sec, not_earlier_than_timestamp) if len(ids) == 0: - raise ValueError(f"No kernel instances found for app {app_name}") + raise ValueError(f"No instances found for app {app_name}") return ids[index] - def get_kernel_instance_ids(self, app_name: str, timeout_in_sec: int, not_earlier_than_timestamp: int = 0): - self.logger.info(f"Resolving IDE instance IDs for app '{app_name}' through SSM tags " - f"in domain '{self.domain_id}' for user '{self.user}'") + def get_instance_ids(self, app_name: str, timeout_in_sec: int, not_earlier_than_timestamp: int = 0): + self.logger.info(f"Resolving IDE instance IDs for app '{app_name}' through SSM tags in domain '{self.domain_id}' " + f"for {f'user' if self.is_user_profile else f'space'} '{self.user_or_space}'") self.log_urls(app_name) - if self.domain_id and self.user: - result = SSMManager().get_studio_user_kgw_instance_ids(self.domain_id, self.user, app_name, - timeout_in_sec, not_earlier_than_timestamp) - elif self.user: - self.logger.warning(f"Domain ID is not set. Will attempt to connect to the latest " - f"active kernel gateway with the name {app_name} in the region {self.current_region} " - f"for user profile {self.user}") - result = SSMManager().get_studio_user_kgw_instance_ids("", self.user, app_name, - timeout_in_sec, not_earlier_than_timestamp) - else: - self.logger.warning(f"Domain ID or user profile name are not set. Will attempt to connect to the latest " - f"active kernel gateway with the name {app_name} in the region {self.current_region}") - result = SSMManager().get_studio_kgw_instance_ids(app_name, timeout_in_sec, not_earlier_than_timestamp) - return result - def get_space_instance_ids(self, app_name, timeout_in_sec): - self.logger.info("Resolving IDE instance IDs through SSM tags") - self.log_urls(app_name) - if self.domain_id and self.space: - result = SSMManager().get_studio_space_app_instance_ids(self.domain_id, self.space, app_name, timeout_in_sec) - elif self.space: + if self.domain_id and self.user_or_space: + result = SSMManager().get_studio_instance_ids(self.domain_id, self.user_or_space, app_name, + timeout_in_sec, not_earlier_than_timestamp, is_user_profile=self.is_user_profile) + elif self.user_or_space: self.logger.warning(f"Domain ID is not set. Will attempt to connect to the latest " - f"active kernel gateway with the name {app_name} in the region {self.current_region} " - f"for space {self.space}") - result = SSMManager().get_studio_space_app_instance_ids("", self.space, app_name, - timeout_in_sec) + f"active {app_name} in the region {self.current_region} " + f"for {'user' if self.is_user_profile else 'space'} {self.user_or_space}") + result = SSMManager().get_studio_instance_ids("", self.user_or_space, app_name, + timeout_in_sec, not_earlier_than_timestamp, is_user_profile=self.is_user_profile) else: - self.logger.warning(f"Domain ID or space name are not set. Will attempt to connect to the latest " - f"active kernel gateway with the name {app_name} in the region {self.current_region}") - result = SSMManager().get_studio_app_instance_ids(app_name, timeout_in_sec) + self.logger.warning(f"Domain ID or {'user' if self.is_user_profile else 'space'} are not set. Will attempt to connect to the latest " + f"active {app_name} in the region {self.current_region}") + result = SSMManager().get_studio_kgw_instance_ids(app_name, timeout_in_sec, not_earlier_than_timestamp) return result def log_urls(self, app_name): self.logger.info(f"Remote logs are at {self.get_cloudwatch_url(app_name)}") - if self.domain_id and self.user: - self.logger.info(f"Remote apps metadata is at {self.get_user_metadata_url()}") + if self.domain_id: + self.logger.info(f"Remote apps metadata is at {self.get_user_or_space_metadata_url()}") def get_cloudwatch_url(self, app_name): - return self.ssh_log.get_ide_cloudwatch_url(self.domain_id, self.user, app_name) + return self.ssh_log.get_ide_cloudwatch_url(self.domain_id, self.user_or_space, app_name) - def get_user_metadata_url(self): - return self.ssh_log.get_ide_metadata_url(self.domain_id, self.user) + def get_user_or_space_metadata_url(self): + return self.ssh_log.get_ide_metadata_url(self.domain_id, self.user_or_space) def create_and_attach_image(self, image_name, ecr_image_name, role_arn, diff --git a/sagemaker_ssh_helper/interactive_sagemaker.py b/sagemaker_ssh_helper/interactive_sagemaker.py index 21fd695..61e09d6 100644 --- a/sagemaker_ssh_helper/interactive_sagemaker.py +++ b/sagemaker_ssh_helper/interactive_sagemaker.py @@ -31,16 +31,15 @@ def set_ping_status(self, ping_status): class SageMakerStudioApp(SageMakerCoreApp): - def __init__(self, domain_id: str, app_name: str, app_type: str, app_status: IDEAppStatus, user_profile_name: str = None, - space_name: str = None) -> None: + def __init__(self, domain_id: str, app_name: str, app_type: str, app_status: IDEAppStatus, user_profile_or_space_name: str, + is_user_profile: bool = True) -> None: super().__init__() self.app_status = app_status self.app_type = app_type self.app_name = app_name - self.user_profile_name = user_profile_name - self.space_name = space_name + self.user_profile_or_space_name = user_profile_or_space_name self.domain_id = domain_id - self.resource_type = "ide" if user_profile_name else "ide-space" + self.resource_type = "ide" if is_user_profile else "ide-space" def __str__(self) -> str: return "{0:<16} {1:<18} {2:<12} {5}.{4}.{3}.{6}".format( @@ -48,7 +47,7 @@ def __str__(self) -> str: self.app_type, str(self.app_status), self.domain_id, - self.user_profile_name if self.user_profile_name else self.space_name, + self.user_profile_or_space_name, self.app_name, SageMakerSecureShellHelper.type_to_fqdn(self.resource_type) ) @@ -163,13 +162,13 @@ def list_ide_apps(self) -> List[SageMakerStudioApp]: if 'SpaceName' in app_dict: space_name = app_dict['SpaceName'] logging.info("Found app %s of type %s for space %s" % (app_name, app_type, space_name)) - app_status = SSHIDE(domain_id, None, self.region, space_name).get_app_status(app_name, app_type) - result.append(SageMakerStudioApp(domain_id, app_dict['AppName'], app_dict['AppType'], app_status, space_name=space_name)) + app_status = SSHIDE(domain_id, space_name, self.region, is_user_profile=False).get_app_status(app_name, app_type) + result.append(SageMakerStudioApp(domain_id, app_dict['AppName'], app_dict['AppType'], app_status, user_profile_or_space_name=space_name, is_user_profile=False)) elif app_type in ['JupyterServer', 'KernelGateway']: user_profile_name = app_dict['UserProfileName'] logging.info("Found app %s of type %s for user %s" % (app_name, app_type, user_profile_name)) - app_status = SSHIDE(domain_id, user_profile_name, self.region).get_app_status(app_name, app_type) - result.append(SageMakerStudioApp(domain_id, app_dict['AppName'], app_dict['AppType'], app_status, user_profile_name=user_profile_name)) + app_status = SSHIDE(domain_id, user_profile_name, self.region, is_user_profile=True).get_app_status(app_name, app_type) + result.append(SageMakerStudioApp(domain_id, app_dict['AppName'], app_dict['AppType'], app_status, user_profile_or_space_name=user_profile_name, is_user_profile=True)) else: logging.info("Unsupported app type %s" % app_type) pass # We don't support other types like 'DetailedProfiler' @@ -289,14 +288,14 @@ def __init__(self, sagemaker: SageMaker, manager: SSMManager, self.manager = manager self.log = log - def list_studio_ide_apps_for_user_and_domain(self, domain_id: Optional[str], user_profile_name: Optional[str]): + def list_studio_ide_apps_for_user_or_space_and_domain(self, domain_id: Optional[str], user_profile_or_space_name: Optional[str]): managed_instances = self.manager.list_all_instances_and_fetch_tags() sagemaker_apps = self.sagemaker.list_ide_apps() result = [] for sagemaker_app in sagemaker_apps: if (sagemaker_app.domain_id == domain_id or domain_id is None or domain_id == "") \ - and (sagemaker_app.user_profile_name == user_profile_name or user_profile_name is None - or user_profile_name == ""): + and (sagemaker_app.user_profile_or_space_name == user_profile_or_space_name or user_profile_or_space_name is None + or user_profile_or_space_name == ""): instance_id = self._find_latest_app_instance_id(managed_instances, sagemaker_app) if instance_id: tags = managed_instances[instance_id] @@ -307,15 +306,15 @@ def list_studio_ide_apps_for_user_and_domain(self, domain_id: Optional[str], use return result def print_studio_ide_apps_for_user_and_domain(self, domain_id: str, user_profile_name: str): - apps: List[SageMakerStudioApp] = self.list_studio_ide_apps_for_user_and_domain(domain_id, user_profile_name) + apps: List[SageMakerStudioApp] = self.list_studio_ide_apps_for_user_or_space_and_domain(domain_id, user_profile_name) for app in apps: print(app) - def list_studio_ide_apps_for_user(self, user_profile_name: str): - return self.list_studio_ide_apps_for_user_and_domain(None, user_profile_name) + def list_studio_ide_apps_for_user_or_space(self, user_profile_or_space_name: str): + return self.list_studio_ide_apps_for_user_or_space_and_domain(None, user_profile_or_space_name) def list_studio_ide_apps(self): - return self.list_studio_ide_apps_for_user_and_domain(None, None) + return self.list_studio_ide_apps_for_user_or_space_and_domain(None, None) @staticmethod def _find_latest_instance_id(managed_instances: Dict[str, Dict[str, str]], @@ -341,8 +340,7 @@ def _find_latest_app_instance_id(managed_instances: Dict[str, Dict[str, str]], s arn = tags['SSHResourceArn'] if 'SSHResourceArn' in tags else '' timestamp = int(tags['SSHTimestamp']) if 'SSHTimestamp' in tags else 0 if (':app/' in arn and arn.endswith(f"/{sagemaker_app.app_name}") - and (f"/{sagemaker_app.user_profile_name}/" in arn or - f"/{sagemaker_app.space_name}/" in arn) + and f"/{sagemaker_app.user_profile_or_space_name}/" in arn and f"/{sagemaker_app.domain_id}/" in arn and timestamp > max_timestamp): result = managed_instance_id diff --git a/sagemaker_ssh_helper/manager.py b/sagemaker_ssh_helper/manager.py index 9a70764..7ce4cc8 100644 --- a/sagemaker_ssh_helper/manager.py +++ b/sagemaker_ssh_helper/manager.py @@ -113,23 +113,14 @@ def get_transformer_instance_ids(self, transform_job_name, timeout_in_sec=0): self.logger.info(f"Querying SSM instance IDs for transform job {transform_job_name}") return self.get_instance_ids('transform-job', transform_job_name, timeout_in_sec) - def get_studio_space_app_instance_ids(self, domain_id, space_name, app_name, timeout_in_sec=0): - self.logger.info(f"Querying SSM instance IDs for SageMaker Studio space {app_name}") + def get_studio_instance_ids(self, domain_id, user_profile_or_space_name, app_name, timeout_in_sec=0, not_earlier_than_timestamp: int = 0, is_user_profile=False): + self.logger.info(f"Querying SSM instance IDs for app '{app_name}' in SageMaker Studio {'kernel gateway' if is_user_profile else 'space'}: '{user_profile_or_space_name}'") if not domain_id: - arn_filter = f":app/.*/{space_name}/" + arn_filter = f":app/.*/{user_profile_or_space_name}/" else: - arn_filter = f":app/{domain_id}/{space_name}/" - return self.get_instance_ids('app', f"{app_name}", timeout_in_sec, - arn_filter_regex=arn_filter) + arn_filter = f":app/{domain_id}/{user_profile_or_space_name}/" - def get_studio_user_kgw_instance_ids(self, domain_id, user_profile_name, kgw_name, timeout_in_sec=0, - not_earlier_than_timestamp: int = 0): - self.logger.info(f"Querying SSM instance IDs for SageMaker Studio kernel gateway: '{kgw_name}'") - if not domain_id: - arn_filter = f":app/.*/{user_profile_name}/" - else: - arn_filter = f":app/{domain_id}/{user_profile_name}/" - return self.get_instance_ids('app', f"{kgw_name}", timeout_in_sec, + return self.get_instance_ids('app', f"{app_name}", timeout_in_sec, arn_filter_regex=arn_filter, not_earlier_than_timestamp=not_earlier_than_timestamp) diff --git a/sagemaker_ssh_helper/sm-local-ssh-ide b/sagemaker_ssh_helper/sm-local-ssh-ide index b8c084b..ccd187f 100644 --- a/sagemaker_ssh_helper/sm-local-ssh-ide +++ b/sagemaker_ssh_helper/sm-local-ssh-ide @@ -4,11 +4,11 @@ # Run `sm-ssh -h` for the high-level interface. # Commands: -# [--domain-id ] [--user-profile-name ] connect-app [--ssh-only] [] -# [--domain-id ] [--user-profile-name ] proxy-host +# [--domain-id ] [--user-profile-or-space-name ] connect-app [--ssh-only] [] +# [--domain-id ] [--user-profile-or-space-name ] proxy-host # run-command # set-domain-id -# set-user-profile-name +# set-user-profile-or-space-name # set-jb-license-server # SageMaker Studio app name for Kernel Gateways is usually the same as the hostname, @@ -16,7 +16,7 @@ # For JupyterServer app it's 'default'. # Tip: to open SageMaker Studio UI in Firefox from command line on macOS, use the following command: -# open -a Firefox $(AWS_PROFILE=terry aws sagemaker create-presigned-domain-url --domain-id d-lnwlaexample --user-profile-name terry-whitlock --query AuthorizedUrl --output text) +# open -a Firefox $(AWS_PROFILE=terry aws sagemaker create-presigned-domain-url --domain-id d-lnwlaexample --user-profile-or-space-name terry-whitlock --query AuthorizedUrl --output text) # replace with your JetBrains License Server host, or leave it as is if you don't use one @@ -38,8 +38,7 @@ source "$dir"/sm-helper-functions 2>/dev/null || source sm-helper-functions echo "sm-local-ssh-ide: Starting in $dir" DOMAIN_ID="" -USER_PROFILE_NAME="" -SPACE_NAME="" +USER_PROFILE_OR_SPACE_NAME="" if [[ "$1" == "--domain-id" ]]; then DOMAIN_ID=$2 @@ -47,39 +46,27 @@ if [[ "$1" == "--domain-id" ]]; then shift fi -if [[ "$1" == "--user-profile-name" ]]; then - USER_PROFILE_NAME=$2 - shift - shift -fi - -if [[ "$1" == "--space-name" ]]; then - SPACE_NAME=$2 +if [[ "$1" == "--user-profile-or-space-name" ]]; then + USER_PROFILE_OR_SPACE_NAME=$2 shift shift fi COMMAND=$1 -if [[ "$COMMAND" != "set-user-profile-name" && "$COMMAND" != "set-domain-id" \ - && "$DOMAIN_ID" == "" && "$USER_PROFILE_NAME" == "" ]]; then +if [[ "$COMMAND" != "set-user-profile-or-space-name" && "$COMMAND" != "set-domain-id" \ + && "$DOMAIN_ID" == "" && "$USER_PROFILE_OR_SPACE_NAME" == "" ]]; then if [ -f ~/.sm-studio-domain-id ]; then DOMAIN_ID="$(cat ~/.sm-studio-domain-id)" else echo "sm-local-ssh-ide: WARNING: SageMaker Studio domain ID is not set."\ "Run 'sm-local-ssh-ide set-domain-id' to override." fi - if [ -f ~/.sm-studio-user-profile-name ]; then - USER_PROFILE_NAME="$(cat ~/.sm-studio-user-profile-name)" + if [ -f ~/.sm-studio-user-profile-or-space-name ]; then + USER_PROFILE_OR_SPACE_NAME="$(cat ~/.sm-studio-user-profile-or-space-name)" else echo "sm-local-ssh-ide: WARNING: SageMaker Studio user profile name is not set."\ - "Run 'sm-local-ssh-ide set-user-profile-name' to override." - fi - if [ -f ~/.sm-studio-space-name ]; then - SPACE_NAME="$(cat ~/.sm-studio-space-name)" - else - echo "sm-local-ssh-ide: WARNING: SageMaker Studio space name is not set."\ - "Run 'sm-local-ssh-ide set-space-name' to override." + "Run 'sm-local-ssh-ide set-user-profile-or-space-name' to override." fi fi @@ -90,31 +77,31 @@ if [[ "$COMMAND" == "proxy-host" ]]; then _check_ssh_proxy_host_name "$SM_SSH_HOST_NAME" "$SM_RESOURCE_TYPE" _export_ssh_key_env_var "$SM_SSH_HOST_NAME" - INSTANCE_ID=$(_generate_key_and_print_instance_id "$SM_SSH_HOST_NAME" "$DOMAIN_ID" "$USER_PROFILE_NAME") + INSTANCE_ID=$(_generate_key_and_print_instance_id "$SM_SSH_HOST_NAME" "$DOMAIN_ID" "$USER_PROFILE_OR_SPACE_NAME") sm-local-start-ssh --proxy-setup-only "${INSTANCE_ID}" -elif [[ "$COMMAND" == "connect-app" ]]; then - - SM_STUDIO_KGW_NAME="$2" +elif [[ "$COMMAND" == connect* ]]; then + SM_STUDIO_APP_NAME="$2" OPTIONS="$3" + if [[ "$COMMAND" == "connect-app" ]]; then + IS_USER_PROFILE="True" + elif [[ "$COMMAND" == "connect-space-app" ]]; then + IS_USER_PROFILE="False" + else + echo "ERROR: Unknown command: '$COMMAND'" + exit 1 + fi + # shellcheck disable=SC2091 - if [ -z "$SPACE_NAME" ]; then - INSTANCE_ID=$(python < ~/.sm-studio-domain-id -elif [[ "$COMMAND" == "set-user-profile-name" ]]; then - USER_PROFILE_NAME="$2" - if [[ "$USER_PROFILE_NAME" == "" ]]; then - echo "sm-local-ssh-ide: ERROR: argument is expected" - exit 1 - fi - echo "sm-local-ssh-ide: Saving SageMaker Studio user profile name into ~/.sm-studio-user-profile-name" - echo "$USER_PROFILE_NAME" > ~/.sm-studio-user-profile-name - -elif [[ "$COMMAND" == "set-space-name" ]]; then - SPACE_NAME="$(echo "$2" | tr '[:upper:]' '[:lower:]')" - if [[ "$SPACE_NAME" == "" ]]; then - echo "sm-local-ssh-ide: ERROR: argument is expected" +elif [[ "$COMMAND" == "set-user-profile-or-space-name" ]]; then + USER_PROFILE_OR_SPACE_NAME="$2" + if [[ "$USER_PROFILE_OR_SPACE_NAME" == "" ]]; then + echo "sm-local-ssh-ide: ERROR: argument is expected" exit 1 fi - echo "sm-local-ssh-ide: Saving SageMaker Studio user profile name into ~/.sm-studio-space-name" - echo "$SPACE_NAME" > ~/.sm-studio-space-name + echo "sm-local-ssh-ide: Saving SageMaker Studio user profile name into ~/.sm-studio-user-profile-or-space-name" + echo "$USER_PROFILE_OR_SPACE_NAME" > ~/.sm-studio-user-profile-or-space-name elif [[ "$COMMAND" == "run-command" ]]; then diff --git a/sagemaker_ssh_helper/sm_ssh.py b/sagemaker_ssh_helper/sm_ssh.py index 98e43cb..cc6c8c6 100644 --- a/sagemaker_ssh_helper/sm_ssh.py +++ b/sagemaker_ssh_helper/sm_ssh.py @@ -15,14 +15,14 @@ class SageMakerSecureShellHelper: - resources = ["ide", "training", "processing", "transform", "inference", "notebook"] + resources = ["ide", "space-ide", "training", "processing", "transform", "inference", "notebook"] @staticmethod def fqdn_to_type(fqdn: str) -> str: if fqdn.endswith(".studio.sagemaker") or fqdn == "studio.sagemaker": return "ide" elif fqdn.endswith(".space.sagemaker") or fqdn == "space.sagemaker": - return "ide-space" + return "space-ide" elif fqdn.endswith(".notebook.sagemaker") or fqdn == "notebook.sagemaker": return "notebook" elif fqdn.endswith(".training.sagemaker") or fqdn == "training.sagemaker": @@ -40,7 +40,7 @@ def fqdn_to_type(fqdn: str) -> str: def type_to_fqdn(cls, resource_type): if resource_type == "ide": return "studio.sagemaker" - elif resource_type == "ide-space": + elif resource_type == "space-ide": return "space.sagemaker" elif resource_type == "notebook": return "notebook.sagemaker" @@ -72,7 +72,7 @@ def fqdn_to_studio_domain_id(fqdn: str) -> str: return '' @staticmethod - def fqdn_to_studio_user_name(fqdn: str) -> str: + def fqdn_to_studio_user_or_space_name(fqdn: str) -> str: if fqdn.count('.') == 4: return fqdn.split('.')[1] if fqdn.count('.') == 3: @@ -80,31 +80,19 @@ def fqdn_to_studio_user_name(fqdn: str) -> str: else: return '' - fqdn_to_studio_space_name = fqdn_to_studio_user_name - @classmethod def _get_arguments(cls, fqdn, resource, command): - user_profile_name = "" - space_name = "" - - if resource.startswith("ide"): + if resource.endswith("ide"): domain_id = SageMakerSecureShellHelper.fqdn_to_studio_domain_id(fqdn) - if resource.endswith("space"): - space_name = SageMakerSecureShellHelper.fqdn_to_studio_space_name(fqdn) - else: - user_profile_name = SageMakerSecureShellHelper.fqdn_to_studio_user_name(fqdn) - - arguments = ["bash", "sm-local-ssh-ide", "--domain-id", domain_id] - - if domain_id and user_profile_name: - arguments.extend(["--user-profile-name", user_profile_name]) - elif domain_id and space_name: - arguments.extend(["--space-name", space_name]) + user_profile_or_space_name = SageMakerSecureShellHelper.fqdn_to_studio_user_or_space_name(fqdn) + arguments = ["bash", "sm-local-ssh-ide", "--domain-id", domain_id, "--user-profile-or-space-name", user_profile_or_space_name] else: arguments = ["bash", f"sm-local-ssh-{resource}"] - if resource.startswith("ide"): + if resource == "ide": arguments.append(cls._sm_ssh_command_to_local_ssh_command(command, "app")) + elif resource == "space-ide": + arguments.append(cls._sm_ssh_command_to_local_ssh_command(command, "space-app")) elif resource == "notebook": arguments.append(cls._sm_ssh_command_to_local_ssh_command(command, "notebook")) elif resource == "training": @@ -153,8 +141,8 @@ def list(self, fqdn): # if-then-else branch for every resource type: if resource.startswith("ide"): domain_id = SageMakerSecureShellHelper.fqdn_to_studio_domain_id(fqdn) - user_profile_name = SageMakerSecureShellHelper.fqdn_to_studio_user_name(fqdn) - interactive_sagemaker.print_studio_ide_apps_for_user_and_domain(domain_id, user_profile_name) + user_profile_or_space_name = SageMakerSecureShellHelper.fqdn_to_studio_user_or_space_name(fqdn) + interactive_sagemaker.print_studio_ide_apps_for_user_and_domain(domain_id, user_profile_or_space_name) elif resource == "notebook": interactive_sagemaker.print_notebook_instances() elif resource == "training": diff --git a/sagemaker_ssh_helper/wrapper.py b/sagemaker_ssh_helper/wrapper.py index a492da7..12e3426 100644 --- a/sagemaker_ssh_helper/wrapper.py +++ b/sagemaker_ssh_helper/wrapper.py @@ -656,7 +656,7 @@ def attach(cls, domain_id, user_profile_name, app_name, sagemaker_session: Sessi return result def get_instance_ids(self, retry: int = None, timeout_in_sec: int = 900): - return self.ide.get_kernel_instance_ids(self.app_name, timeout_in_sec=timeout_in_sec, + return self.ide.get_instance_ids(self.app_name, timeout_in_sec=timeout_in_sec, not_earlier_than_timestamp=self.not_earlier_than_timestamp) def get_cloudwatch_url(self): diff --git a/tests/test_cli.py b/tests/test_cli.py index 5bfe289..f74fbc3 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -30,18 +30,18 @@ def test_fqdn_to_name(): def test_fqdn_to_studio_user_and_domain(): sm_ssh = SageMakerSecureShellHelper() assert sm_ssh.fqdn_to_studio_domain_id("ssh-training-job.training.sagemaker") == "" - assert sm_ssh.fqdn_to_studio_user_name("ssh-training-job.training.sagemaker") == "" + assert sm_ssh.fqdn_to_studio_user_or_space_name("ssh-training-job.training.sagemaker") == "" assert sm_ssh.fqdn_to_studio_domain_id("ssh-test-ds2-cpu.studio.sagemaker") == "" - assert sm_ssh.fqdn_to_studio_user_name("ssh-test-ds2-cpu.studio.sagemaker") == "" + assert sm_ssh.fqdn_to_studio_user_or_space_name("ssh-test-ds2-cpu.studio.sagemaker") == "" assert sm_ssh.fqdn_to_studio_domain_id( "ssh-test-ds2-cpu.test-data-science.d-egm0dexample.studio.sagemaker" ) == "d-egm0dexample" - assert sm_ssh.fqdn_to_studio_user_name( + assert sm_ssh.fqdn_to_studio_user_or_space_name( "ssh-test-ds2-cpu.test-data-science.d-egm0dexample.studio.sagemaker" ) == "test-data-science" assert sm_ssh.fqdn_to_studio_domain_id( "test-data-science.d-egm0dexample.studio.sagemaker" ) == "d-egm0dexample" - assert sm_ssh.fqdn_to_studio_user_name( + assert sm_ssh.fqdn_to_studio_user_or_space_name( "test-data-science.d-egm0dexample.studio.sagemaker" ) == "test-data-science" diff --git a/tests/test_ide.py b/tests/test_ide.py index 625a1c8..0325386 100644 --- a/tests/test_ide.py +++ b/tests/test_ide.py @@ -323,7 +323,7 @@ def test_studio_default_domain_multiple_users(request): @pytest.mark.parametrize('user_profile_name', ['test-firefox']) def test_studio_notebook_in_firefox(request, user_profile_name): - ide = SSHIDE(request.config.getini('sagemaker_studio_domain'), user_profile_name) + ide = SSHIDE(request.config.getini('sagemaker_studio_domain'), user_profile_name, is_user_profile=True) local_user_id = os.environ['LOCAL_USER_ID'] jb_server_host = os.environ['JB_LICENSE_SERVER_HOST'] @@ -371,7 +371,7 @@ def test_studio_notebook_in_firefox(request, user_profile_name): browser_automation.restart_kernel_and_run_all_cells() data_science_kernel = "sagemaker-data-science-ml-m5-large-6590da95dc67eec021b14bedc036" # noqa - studio_id = ide.get_kernel_instance_id( + studio_id = ide.get_instance_id( data_science_kernel, timeout_in_sec=300, not_earlier_than_timestamp=current_time_stamp diff --git a/tests/test_ssm_manager.py b/tests/test_ssm_manager.py index a4cb2fc..742348f 100644 --- a/tests/test_ssm_manager.py +++ b/tests/test_ssm_manager.py @@ -199,9 +199,10 @@ def test_can_filter_by_domain_and_user(): }, }) - ids = manager.get_studio_user_kgw_instance_ids( + ids = manager.get_studio_instance_ids( "d-0123456789bc", "default-1111111111111", - "sagemaker-data-science-ml-m5-large-1234567890abcdef0" + "sagemaker-data-science-ml-m5-large-1234567890abcdef0", + is_user_profile=True ) assert len(ids) == 1 assert ids[0] == "mi-01234567890abcd08" @@ -234,9 +235,10 @@ def test_can_filter_by_user_with_latest_domain(): }, }) - ids = manager.get_studio_user_kgw_instance_ids( + ids = manager.get_studio_instance_ids( "", "default-1111111111111", - "sagemaker-data-science-ml-m5-large-1234567890abcdef0" + "sagemaker-data-science-ml-m5-large-1234567890abcdef0", + is_user_profile=True ) assert len(ids) == 2 assert ids[0] == "mi-01234567890abcd08" @@ -327,7 +329,7 @@ def test_can_list_ssh_and_non_ssh_instances(): interactive_sagemaker = InteractiveSageMaker(sagemaker, manager) - apps = interactive_sagemaker.list_studio_ide_apps_for_user_and_domain( + apps = interactive_sagemaker.list_studio_ide_apps_for_user_or_space_and_domain( "d-0123456789bc", "janedoe", ) assert len(apps) == 4 @@ -336,12 +338,12 @@ def test_can_list_ssh_and_non_ssh_instances(): assert apps[0].ssm_instance_id == "mi-01234567890abcd05" assert apps[0].ssh_owner == "AIDACKCEVSQ6C2EXAMPLE:janedoe@SSO" - apps = interactive_sagemaker.list_studio_ide_apps_for_user_and_domain( + apps = interactive_sagemaker.list_studio_ide_apps_for_user_or_space_and_domain( "d-0123456789bc", "terry", ) assert len(apps) == 2 - apps = interactive_sagemaker.list_studio_ide_apps_for_user( + apps = interactive_sagemaker.list_studio_ide_apps_for_user_or_space( "terry", ) assert len(apps) == 4 From 0510ee3cbfc5b4d8d5e9e11198afcac7e7dcdfe5 Mon Sep 17 00:00:00 2001 From: Elisha Ben Yosef Date: Sun, 10 Nov 2024 10:25:19 +0200 Subject: [PATCH 4/8] cleanup code --- sagemaker_ssh_helper/ide.py | 15 ++++++++----- sagemaker_ssh_helper/interactive_sagemaker.py | 8 ++++--- sagemaker_ssh_helper/manager.py | 4 ---- sagemaker_ssh_helper/sm_ssh.py | 4 ++++ tests/test_ssm_manager.py | 22 +++++++++---------- 5 files changed, 29 insertions(+), 24 deletions(-) diff --git a/sagemaker_ssh_helper/ide.py b/sagemaker_ssh_helper/ide.py index 0cafa87..1af179d 100644 --- a/sagemaker_ssh_helper/ide.py +++ b/sagemaker_ssh_helper/ide.py @@ -116,7 +116,8 @@ def get_app_status(self, app_name: str, app_type: str = 'KernelGateway') -> IDEA "AppName": app_name, } - describe_app_request_params.update({"UserProfileName": self.user_or_space} if self.is_user_profile else {"SpaceName": self.user_or_space}) + describe_app_request_params.update( + {"UserProfileName": self.user_or_space} if self.is_user_profile else {"SpaceName": self.user_or_space}) try: response = self.client.describe_app(**describe_app_request_params) @@ -148,7 +149,8 @@ def delete_app(self, app_name, app_type, wait: bool = True): "AppName": app_name, } - delete_app_request_params.update({"UserProfileName": self.user_or_space} if self.is_user_profile else {"SpaceName": self.user_or_space}) + delete_app_request_params.update( + {"UserProfileName": self.user_or_space} if self.is_user_profile else {"SpaceName": self.user_or_space}) _ = self.client.delete_app(**delete_app_request_params) except ClientError as e: @@ -188,7 +190,8 @@ def create_app(self, app_name, app_type, instance_type, image_arn, "ResourceSpec": resource_spec, } - create_app_request_params.update({"UserProfileName": self.user_or_space} if self.is_user_profile else {"SpaceName": self.user_or_space}) + create_app_request_params.update( + {"UserProfileName": self.user_or_space} if self.is_user_profile else {"SpaceName": self.user_or_space}) _ = self.client.create_app(**create_app_request_params) status = self.get_app_status(app_name) @@ -235,12 +238,12 @@ def get_instance_ids(self, app_name: str, timeout_in_sec: int, not_earlier_than_ result = SSMManager().get_studio_instance_ids("", self.user_or_space, app_name, timeout_in_sec, not_earlier_than_timestamp, is_user_profile=self.is_user_profile) else: - self.logger.warning(f"Domain ID or {'user' if self.is_user_profile else 'space'} are not set. Will attempt to connect to the latest " - f"active {app_name} in the region {self.current_region}") + self.logger.warning( + f"Domain ID or {'user' if self.is_user_profile else 'space'} are not set. Will attempt to connect to the latest " + f"active {app_name} in the region {self.current_region}") result = SSMManager().get_studio_kgw_instance_ids(app_name, timeout_in_sec, not_earlier_than_timestamp) return result - def log_urls(self, app_name): self.logger.info(f"Remote logs are at {self.get_cloudwatch_url(app_name)}") if self.domain_id: diff --git a/sagemaker_ssh_helper/interactive_sagemaker.py b/sagemaker_ssh_helper/interactive_sagemaker.py index 61e09d6..996e283 100644 --- a/sagemaker_ssh_helper/interactive_sagemaker.py +++ b/sagemaker_ssh_helper/interactive_sagemaker.py @@ -31,7 +31,7 @@ def set_ping_status(self, ping_status): class SageMakerStudioApp(SageMakerCoreApp): - def __init__(self, domain_id: str, app_name: str, app_type: str, app_status: IDEAppStatus, user_profile_or_space_name: str, + def __init__(self, domain_id: str, user_profile_or_space_name: str, app_name: str, app_type: str, app_status: IDEAppStatus, is_user_profile: bool = True) -> None: super().__init__() self.app_status = app_status @@ -163,12 +163,14 @@ def list_ide_apps(self) -> List[SageMakerStudioApp]: space_name = app_dict['SpaceName'] logging.info("Found app %s of type %s for space %s" % (app_name, app_type, space_name)) app_status = SSHIDE(domain_id, space_name, self.region, is_user_profile=False).get_app_status(app_name, app_type) - result.append(SageMakerStudioApp(domain_id, app_dict['AppName'], app_dict['AppType'], app_status, user_profile_or_space_name=space_name, is_user_profile=False)) + result.append(SageMakerStudioApp(domain_id, user_profile_or_space_name=space_name, app_name=app_dict['AppName'], + app_type=app_dict['AppType'], app_status=app_status, is_user_profile=False)) elif app_type in ['JupyterServer', 'KernelGateway']: user_profile_name = app_dict['UserProfileName'] logging.info("Found app %s of type %s for user %s" % (app_name, app_type, user_profile_name)) app_status = SSHIDE(domain_id, user_profile_name, self.region, is_user_profile=True).get_app_status(app_name, app_type) - result.append(SageMakerStudioApp(domain_id, app_dict['AppName'], app_dict['AppType'], app_status, user_profile_or_space_name=user_profile_name, is_user_profile=True)) + result.append(SageMakerStudioApp(domain_id, user_profile_or_space_name=user_profile_name, app_name=app_dict['AppName'], + app_type=app_dict['AppType'], app_status=app_status, is_user_profile=True)) else: logging.info("Unsupported app type %s" % app_type) pass # We don't support other types like 'DetailedProfiler' diff --git a/sagemaker_ssh_helper/manager.py b/sagemaker_ssh_helper/manager.py index 7ce4cc8..988c3e5 100644 --- a/sagemaker_ssh_helper/manager.py +++ b/sagemaker_ssh_helper/manager.py @@ -129,10 +129,6 @@ def get_studio_kgw_instance_ids(self, kgw_name, timeout_in_sec=0, not_earlier_th return self.get_instance_ids('app', f"{kgw_name}", timeout_in_sec, not_earlier_than_timestamp) - def get_studio_app_instance_ids(self, app_name, timeout_in_sec=0): - self.logger.info(f"Querying SSM instance IDs for SageMaker Studio space {app_name}") - return self.get_instance_ids('app', f"{app_name}", timeout_in_sec) - def get_notebook_instance_ids(self, instance_name, timeout_in_sec=0): self.logger.info(f"Querying SSM instance IDs for SageMaker notebook instance {instance_name}") return self.get_instance_ids('notebook-instance', f"{instance_name}", diff --git a/sagemaker_ssh_helper/sm_ssh.py b/sagemaker_ssh_helper/sm_ssh.py index cc6c8c6..305c4ec 100644 --- a/sagemaker_ssh_helper/sm_ssh.py +++ b/sagemaker_ssh_helper/sm_ssh.py @@ -82,9 +82,13 @@ def fqdn_to_studio_user_or_space_name(fqdn: str) -> str: @classmethod def _get_arguments(cls, fqdn, resource, command): + domain_id = "" + user_profile_or_space_name = "" + if resource.endswith("ide"): domain_id = SageMakerSecureShellHelper.fqdn_to_studio_domain_id(fqdn) user_profile_or_space_name = SageMakerSecureShellHelper.fqdn_to_studio_user_or_space_name(fqdn) + if domain_id and user_profile_or_space_name: arguments = ["bash", "sm-local-ssh-ide", "--domain-id", domain_id, "--user-profile-or-space-name", user_profile_or_space_name] else: arguments = ["bash", f"sm-local-ssh-{resource}"] diff --git a/tests/test_ssm_manager.py b/tests/test_ssm_manager.py index 742348f..84d8ef9 100644 --- a/tests/test_ssm_manager.py +++ b/tests/test_ssm_manager.py @@ -312,17 +312,17 @@ def test_can_list_ssh_and_non_ssh_instances(): sagemaker = SageMaker('eu-west-1') sagemaker.list_ide_apps = Mock(return_value=[ - SageMakerStudioApp("d-0123456789bc", "default", "JupyterServer", IDEAppStatus("InService"), "janedoe"), - SageMakerStudioApp("d-0123456789bc", "ssh-test-kgw", "KernelGateway", IDEAppStatus("InService"), "janedoe"), - SageMakerStudioApp("d-0123456789bc", "data-science-m5-no-ssh", "KernelGateway", IDEAppStatus("InService"), "janedoe"), - SageMakerStudioApp("d-0123456789bc", "data-science-g4-no-ssh", "KernelGateway", IDEAppStatus("Offline"), "janedoe"), - SageMakerStudioApp("d-0123456789bc", "sagemaker-data-science-ml-m5-large-1234567890abcdef0", "KernelGateway", - IDEAppStatus("Offline"), "terry"), - SageMakerStudioApp("d-0123456789bc", "data-science-m5-no-ssh", "KernelGateway", IDEAppStatus("Offline"), "terry"), - - SageMakerStudioApp("d-0123456789ab", "sagemaker-data-science-ml-m5-large-1234567890abcdef0", "KernelGateway", - IDEAppStatus("Offline"), "terry"), - SageMakerStudioApp("d-0123456789ab", "data-science-m5-no-ssh", "KernelGateway", IDEAppStatus("Offline"), "terry"), + SageMakerStudioApp("d-0123456789bc", "janedoe", "default", "JupyterServer", IDEAppStatus("InService")), + SageMakerStudioApp("d-0123456789bc", "janedoe", "ssh-test-kgw", "KernelGateway", IDEAppStatus("InService")), + SageMakerStudioApp("d-0123456789bc", "janedoe", "data-science-m5-no-ssh", "KernelGateway", IDEAppStatus("InService")), + SageMakerStudioApp("d-0123456789bc", "janedoe", "data-science-g4-no-ssh", "KernelGateway", IDEAppStatus("Offline")), + SageMakerStudioApp("d-0123456789bc", "terry", "sagemaker-data-science-ml-m5-large-1234567890abcdef0", "KernelGateway", + IDEAppStatus("Offline")), + SageMakerStudioApp("d-0123456789bc", "terry", "data-science-m5-no-ssh", "KernelGateway", IDEAppStatus("Offline")), + + SageMakerStudioApp("d-0123456789ab", "terry", "sagemaker-data-science-ml-m5-large-1234567890abcdef0", "KernelGateway", + IDEAppStatus("Offline")), + SageMakerStudioApp("d-0123456789ab", "terry", "data-science-m5-no-ssh", "KernelGateway", IDEAppStatus("Offline")), # LocalApp("janedoe", "AIDACKCEVSQ6C2EXAMPLE:janedoe@SSO", "macOS 13.5.1"), # LocalApp("terry", "AIDACKCEVSQ6C2EXAMPLE:terry@SSO", "Windows 10 Pro"), ]) From 074ec715cfacc6377080d5c9f51d5618b7a37bcc Mon Sep 17 00:00:00 2001 From: Elisha Ben Yosef Date: Sun, 10 Nov 2024 11:09:36 +0200 Subject: [PATCH 5/8] Fix logs location and small bugs --- sagemaker_ssh_helper/ide.py | 4 ++-- sagemaker_ssh_helper/interactive_sagemaker.py | 2 +- sagemaker_ssh_helper/log.py | 16 ++++++++++------ sagemaker_ssh_helper/sm_ssh.py | 3 ++- 4 files changed, 15 insertions(+), 10 deletions(-) diff --git a/sagemaker_ssh_helper/ide.py b/sagemaker_ssh_helper/ide.py index 1af179d..c7240e0 100644 --- a/sagemaker_ssh_helper/ide.py +++ b/sagemaker_ssh_helper/ide.py @@ -250,10 +250,10 @@ def log_urls(self, app_name): self.logger.info(f"Remote apps metadata is at {self.get_user_or_space_metadata_url()}") def get_cloudwatch_url(self, app_name): - return self.ssh_log.get_ide_cloudwatch_url(self.domain_id, self.user_or_space, app_name) + return self.ssh_log.get_ide_cloudwatch_url(self.domain_id, self.user_or_space, app_name, self.is_user_profile) def get_user_or_space_metadata_url(self): - return self.ssh_log.get_ide_metadata_url(self.domain_id, self.user_or_space) + return self.ssh_log.get_ide_metadata_url(self.domain_id, self.user_or_space, self.is_user_profile) def create_and_attach_image(self, image_name, ecr_image_name, role_arn, diff --git a/sagemaker_ssh_helper/interactive_sagemaker.py b/sagemaker_ssh_helper/interactive_sagemaker.py index 996e283..8725ff2 100644 --- a/sagemaker_ssh_helper/interactive_sagemaker.py +++ b/sagemaker_ssh_helper/interactive_sagemaker.py @@ -39,7 +39,7 @@ def __init__(self, domain_id: str, user_profile_or_space_name: str, app_name: st self.app_name = app_name self.user_profile_or_space_name = user_profile_or_space_name self.domain_id = domain_id - self.resource_type = "ide" if is_user_profile else "ide-space" + self.resource_type = "ide" if is_user_profile else "space-ide" def __str__(self) -> str: return "{0:<16} {1:<18} {2:<12} {5}.{4}.{3}.{6}".format( diff --git a/sagemaker_ssh_helper/log.py b/sagemaker_ssh_helper/log.py index 0c13791..b21d5a1 100644 --- a/sagemaker_ssh_helper/log.py +++ b/sagemaker_ssh_helper/log.py @@ -197,22 +197,26 @@ def get_transform_metadata_url(self, transform_job_name): f"sagemaker/home?region={self.region_name}#" \ f"/transform-jobs/{transform_job_name}" - def get_ide_cloudwatch_url(self, domain, user, app_name): - app_type = 'JupyterServer' if app_name == 'default' else 'KernelGateway' - if user: + def get_ide_cloudwatch_url(self, domain, user_or_space, app_name, is_user_profile=True): + if is_user_profile: + app_type = 'JupyterServer' if app_name == 'default' else 'KernelGateway' + else: + app_type = 'JupyterLab' + if user_or_space: return f"https://{self.aws_console.get_console_domain()}/" \ f"cloudwatch/home?region={self.region_name}#" \ f"logsV2:log-groups/log-group/$252Faws$252Fsagemaker$252Fstudio" \ - f"$3FlogStreamNameFilter$3D{domain}$252F{user}$252F{app_type}$252F{app_name}" + f"$3FlogStreamNameFilter$3D{domain}$252F{user_or_space}$252F{app_type}$252F{app_name}" return f"https://{self.aws_console.get_console_domain()}/" \ f"cloudwatch/home?region={self.region_name}#" \ f"logsV2:log-groups/log-group/$252Faws$252Fsagemaker$252Fstudio" \ f"$3FlogStreamNameFilter$3D{app_type}$252F{app_name}" - def get_ide_metadata_url(self, domain, user): + def get_ide_metadata_url(self, domain, user_or_space, is_user_profile=True): + scope = 'user' if is_user_profile else 'space' return f"https://{self.aws_console.get_console_domain()}/" \ f"sagemaker/home?region={self.region_name}#" \ - f"/studio/{domain}/user/{user}" + f"/studio/{domain}/{scope}/{user_or_space}" def count_sns_notifications(self, topic_name: str, period: timedelta): cloudwatch_resource = boto3.resource('cloudwatch', region_name=self.region_name) diff --git a/sagemaker_ssh_helper/sm_ssh.py b/sagemaker_ssh_helper/sm_ssh.py index 305c4ec..75f1a1c 100644 --- a/sagemaker_ssh_helper/sm_ssh.py +++ b/sagemaker_ssh_helper/sm_ssh.py @@ -143,10 +143,11 @@ def list(self, fqdn): for resource in self.resources: if resource_type == resource or resource_type == "all": # if-then-else branch for every resource type: - if resource.startswith("ide"): + if resource.endswith("ide"): domain_id = SageMakerSecureShellHelper.fqdn_to_studio_domain_id(fqdn) user_profile_or_space_name = SageMakerSecureShellHelper.fqdn_to_studio_user_or_space_name(fqdn) interactive_sagemaker.print_studio_ide_apps_for_user_and_domain(domain_id, user_profile_or_space_name) + break elif resource == "notebook": interactive_sagemaker.print_notebook_instances() elif resource == "training": From a73e42e66bf4f6a567b46249a18319ed7a68d4a8 Mon Sep 17 00:00:00 2001 From: Elisha Ben Yosef Date: Sun, 10 Nov 2024 14:35:27 +0200 Subject: [PATCH 6/8] remove debug echo --- sagemaker_ssh_helper/sm-local-ssh-ide | 1 - 1 file changed, 1 deletion(-) diff --git a/sagemaker_ssh_helper/sm-local-ssh-ide b/sagemaker_ssh_helper/sm-local-ssh-ide index ccd187f..a967526 100644 --- a/sagemaker_ssh_helper/sm-local-ssh-ide +++ b/sagemaker_ssh_helper/sm-local-ssh-ide @@ -101,7 +101,6 @@ import logging; logging.basicConfig(level=logging.INFO); SSHIDE("$DOMAIN_ID", user_or_space="$USER_PROFILE_OR_SPACE_NAME", is_user_profile=$IS_USER_PROFILE).print_instance_id("$SM_STUDIO_APP_NAME", timeout_in_sec=300) EOF ) - echo "${INSTANCE_ID}" if [[ "$OPTIONS" == "--ssh-only" ]]; then echo "sm-local-ssh-ide: Connecting only SSH to local port 10022 (got the flag --ssh-only)" From a3bfb7cede5cd1646af6090f1aa9dd7f413a27a6 Mon Sep 17 00:00:00 2001 From: Elisha Ben Yosef Date: Mon, 2 Dec 2024 11:17:21 +0200 Subject: [PATCH 7/8] fix SM_RESOURCE_TYPE --- sagemaker_ssh_helper/sm-local-ssh-ide | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sagemaker_ssh_helper/sm-local-ssh-ide b/sagemaker_ssh_helper/sm-local-ssh-ide index a967526..1c9e114 100644 --- a/sagemaker_ssh_helper/sm-local-ssh-ide +++ b/sagemaker_ssh_helper/sm-local-ssh-ide @@ -72,7 +72,7 @@ fi if [[ "$COMMAND" == "proxy-host" ]]; then SM_SSH_HOST_NAME="$2" - SM_RESOURCE_TYPE="studio" + SM_RESOURCE_TYPE="space" _check_ssh_proxy_host_name "$SM_SSH_HOST_NAME" "$SM_RESOURCE_TYPE" _export_ssh_key_env_var "$SM_SSH_HOST_NAME" From 6cb5e0fe4a6d7aff633880cc9474a464932a16a8 Mon Sep 17 00:00:00 2001 From: Elisha Ben Yosef Date: Mon, 2 Dec 2024 11:28:13 +0200 Subject: [PATCH 8/8] Add support for proxy command --- sagemaker_ssh_helper/wrapper.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/sagemaker_ssh_helper/wrapper.py b/sagemaker_ssh_helper/wrapper.py index 12e3426..9c4ffba 100644 --- a/sagemaker_ssh_helper/wrapper.py +++ b/sagemaker_ssh_helper/wrapper.py @@ -183,7 +183,7 @@ def print_ssh_info(self): @classmethod def attach_to_resource(cls, fqdn: str, domain_id: str = '', - user_profile_name: str = '', + user_profile_or_space_name: str = '', sagemaker_session: Session = None): resource_type = SageMakerSecureShellHelper.fqdn_to_type(fqdn) resource_name = SageMakerSecureShellHelper.fqdn_to_name(fqdn) @@ -195,8 +195,8 @@ def attach_to_resource(cls, fqdn: str, return SSHProcessorWrapper.attach(resource_name, sagemaker_session) elif resource_type == 'transform': return SSHTransformerWrapper.attach(resource_name, sagemaker_session) - elif resource_type == 'ide': - return SSHIDEWrapper.attach(domain_id, user_profile_name, resource_name, sagemaker_session) + elif resource_type.endswith('ide'): + return SSHIDEWrapper.attach(domain_id, user_profile_or_space_name, resource_name, sagemaker_session) elif resource_type == 'notebook': return SSHNotebookInstanceWrapper.attach(resource_name, sagemaker_session) else: @@ -643,12 +643,12 @@ def __init__(self, self.not_earlier_than_timestamp = 0 @classmethod - def attach(cls, domain_id, user_profile_name, app_name, sagemaker_session: Session = None, + def attach(cls, domain_id, user_profile_or_space_name, app_name, sagemaker_session: Session = None, not_earlier_than_timestamp: int = 0) -> SSHIDEWrapper: sagemaker_session = sagemaker_session or sagemaker.Session() result = SSHIDEWrapper( '', - SSHIDE(domain_id, user_profile_name, sagemaker_session.boto_region_name), + SSHIDE(domain_id, user_profile_or_space_name, sagemaker_session.boto_region_name), connection_wait_time_seconds=0 ) result.app_name = app_name