diff --git a/sagemaker_ssh_helper/ide.py b/sagemaker_ssh_helper/ide.py index 6f05a8e..c7240e0 100644 --- a/sagemaker_ssh_helper/ide.py +++ b/sagemaker_ssh_helper/ide.py @@ -46,12 +46,13 @@ def __init__(self, arn, version_arn) -> None: class SSHIDE: logger = logging.getLogger('sagemaker-ssh-helper:SSHIDE') - def __init__(self, domain_id: str, user: str, region_name: str = None): - self.user = user + def __init__(self, domain_id: str, user_or_space: str = None, region_name: str = None, is_user_profile: bool = True): + self.user_or_space = user_or_space self.domain_id = domain_id self.current_region = region_name or boto3.session.Session().region_name self.client = boto3.client('sagemaker', region_name=self.current_region) self.ssh_log = SSHLog(region_name=self.current_region) + self.is_user_profile = is_user_profile def create_ssh_kernel_app(self, app_name: str, image_name_or_arn='sagemaker-datascience-38', @@ -108,13 +109,18 @@ def get_app_status(self, app_name: str, app_type: str = 'KernelGateway') -> IDEA :return: None | 'InService' | 'Deleted' | 'Deleting' | 'Failed' | 'Pending' """ response = None + + describe_app_request_params = { + "DomainId": self.domain_id, + "AppType": app_type, + "AppName": app_name, + } + + describe_app_request_params.update( + {"UserProfileName": self.user_or_space} if self.is_user_profile else {"SpaceName": self.user_or_space}) + try: - response = self.client.describe_app( - DomainId=self.domain_id, - AppType=app_type, - UserProfileName=self.user, - AppName=app_name, - ) + response = self.client.describe_app(**describe_app_request_params) except ClientError as e: error_code = e.response.get("Error", {}).get("Code") if error_code == 'ResourceNotFound': @@ -137,12 +143,16 @@ def delete_app(self, app_name, app_type, wait: bool = True): self.logger.info(f"Deleting app {app_name}") try: - _ = self.client.delete_app( - DomainId=self.domain_id, - AppType=app_type, - UserProfileName=self.user, - AppName=app_name, - ) + delete_app_request_params = { + "DomainId": self.domain_id, + "AppType": app_type, + "AppName": app_name, + } + + delete_app_request_params.update( + {"UserProfileName": self.user_or_space} if self.is_user_profile else {"SpaceName": self.user_or_space}) + + _ = self.client.delete_app(**delete_app_request_params) except ClientError as e: # probably, already deleted code = e.response.get("Error", {}).get("Code") @@ -173,13 +183,17 @@ def create_app(self, app_name, app_type, instance_type, image_arn, if lifecycle_arn: resource_spec['LifecycleConfigArn'] = lifecycle_arn - _ = self.client.create_app( - DomainId=self.domain_id, - AppType=app_type, - AppName=app_name, - UserProfileName=self.user, - ResourceSpec=resource_spec, - ) + create_app_request_params = { + "DomainId": self.domain_id, + "AppType": app_type, + "AppName": app_name, + "ResourceSpec": resource_spec, + } + + create_app_request_params.update( + {"UserProfileName": self.user_or_space} if self.is_user_profile else {"SpaceName": self.user_or_space}) + + _ = self.client.create_app(**create_app_request_params) status = self.get_app_status(app_name) while status.is_pending(): self.logger.info(f"Waiting for the InService status. Current status: {status}") @@ -199,45 +213,47 @@ def resolve_sagemaker_kernel_image_arn(self, image_name): sagemaker_account_id = "470317259841" # eu-west-1, TODO: check all images return f"arn:aws:sagemaker:{self.current_region}:{sagemaker_account_id}:image/{image_name}" - def print_kernel_instance_id(self, app_name, timeout_in_sec, index: int = 0): - print(self.get_kernel_instance_id(app_name, timeout_in_sec, index)) + def print_instance_id(self, app_name, timeout_in_sec, index: int = 0): + print(self.get_instance_id(app_name, timeout_in_sec, index)) - def get_kernel_instance_id(self, app_name, timeout_in_sec, index: int = 0, - not_earlier_than_timestamp: int = 0): - ids = self.get_kernel_instance_ids(app_name, timeout_in_sec, not_earlier_than_timestamp) + def get_instance_id(self, app_name, timeout_in_sec, index: int = 0, + not_earlier_than_timestamp: int = 0): + ids = self.get_instance_ids(app_name, timeout_in_sec, not_earlier_than_timestamp) if len(ids) == 0: - raise ValueError(f"No kernel instances found for app {app_name}") + raise ValueError(f"No instances found for app {app_name}") return ids[index] - def get_kernel_instance_ids(self, app_name: str, timeout_in_sec: int, not_earlier_than_timestamp: int = 0): - self.logger.info(f"Resolving IDE instance IDs for app '{app_name}' through SSM tags " - f"in domain '{self.domain_id}' for user '{self.user}'") + def get_instance_ids(self, app_name: str, timeout_in_sec: int, not_earlier_than_timestamp: int = 0): + self.logger.info(f"Resolving IDE instance IDs for app '{app_name}' through SSM tags in domain '{self.domain_id}' " + f"for {f'user' if self.is_user_profile else f'space'} '{self.user_or_space}'") self.log_urls(app_name) - if self.domain_id and self.user: - result = SSMManager().get_studio_user_kgw_instance_ids(self.domain_id, self.user, app_name, - timeout_in_sec, not_earlier_than_timestamp) - elif self.user: + + if self.domain_id and self.user_or_space: + result = SSMManager().get_studio_instance_ids(self.domain_id, self.user_or_space, app_name, + timeout_in_sec, not_earlier_than_timestamp, is_user_profile=self.is_user_profile) + elif self.user_or_space: self.logger.warning(f"Domain ID is not set. Will attempt to connect to the latest " - f"active kernel gateway with the name {app_name} in the region {self.current_region} " - f"for user profile {self.user}") - result = SSMManager().get_studio_user_kgw_instance_ids("", self.user, app_name, - timeout_in_sec, not_earlier_than_timestamp) + f"active {app_name} in the region {self.current_region} " + f"for {'user' if self.is_user_profile else 'space'} {self.user_or_space}") + result = SSMManager().get_studio_instance_ids("", self.user_or_space, app_name, + timeout_in_sec, not_earlier_than_timestamp, is_user_profile=self.is_user_profile) else: - self.logger.warning(f"Domain ID or user profile name are not set. Will attempt to connect to the latest " - f"active kernel gateway with the name {app_name} in the region {self.current_region}") + self.logger.warning( + f"Domain ID or {'user' if self.is_user_profile else 'space'} are not set. Will attempt to connect to the latest " + f"active {app_name} in the region {self.current_region}") result = SSMManager().get_studio_kgw_instance_ids(app_name, timeout_in_sec, not_earlier_than_timestamp) return result def log_urls(self, app_name): self.logger.info(f"Remote logs are at {self.get_cloudwatch_url(app_name)}") - if self.domain_id and self.user: - self.logger.info(f"Remote apps metadata is at {self.get_user_metadata_url()}") + if self.domain_id: + self.logger.info(f"Remote apps metadata is at {self.get_user_or_space_metadata_url()}") def get_cloudwatch_url(self, app_name): - return self.ssh_log.get_ide_cloudwatch_url(self.domain_id, self.user, app_name) + return self.ssh_log.get_ide_cloudwatch_url(self.domain_id, self.user_or_space, app_name, self.is_user_profile) - def get_user_metadata_url(self): - return self.ssh_log.get_ide_metadata_url(self.domain_id, self.user) + def get_user_or_space_metadata_url(self): + return self.ssh_log.get_ide_metadata_url(self.domain_id, self.user_or_space, self.is_user_profile) def create_and_attach_image(self, image_name, ecr_image_name, role_arn, diff --git a/sagemaker_ssh_helper/interactive_sagemaker.py b/sagemaker_ssh_helper/interactive_sagemaker.py index 73a48d2..8725ff2 100644 --- a/sagemaker_ssh_helper/interactive_sagemaker.py +++ b/sagemaker_ssh_helper/interactive_sagemaker.py @@ -31,15 +31,15 @@ def set_ping_status(self, ping_status): class SageMakerStudioApp(SageMakerCoreApp): - def __init__(self, domain_id: str, user_profile_name: str, app_name: str, app_type: str, - app_status: IDEAppStatus) -> None: + def __init__(self, domain_id: str, user_profile_or_space_name: str, app_name: str, app_type: str, app_status: IDEAppStatus, + is_user_profile: bool = True) -> None: super().__init__() self.app_status = app_status self.app_type = app_type self.app_name = app_name - self.user_profile_name = user_profile_name + self.user_profile_or_space_name = user_profile_or_space_name self.domain_id = domain_id - self.resource_type = "ide" + self.resource_type = "ide" if is_user_profile else "space-ide" def __str__(self) -> str: return "{0:<16} {1:<18} {2:<12} {5}.{4}.{3}.{6}".format( @@ -47,7 +47,7 @@ def __str__(self) -> str: self.app_type, str(self.app_status), self.domain_id, - self.user_profile_name, + self.user_profile_or_space_name, self.app_name, SageMakerSecureShellHelper.type_to_fqdn(self.resource_type) ) @@ -160,17 +160,17 @@ def list_ide_apps(self) -> List[SageMakerStudioApp]: app_name = app_dict['AppName'] app_type = app_dict['AppType'] if 'SpaceName' in app_dict: - logging.info("Don't support spaces: skipping app %s of type %s" % (app_name, app_type)) - pass + space_name = app_dict['SpaceName'] + logging.info("Found app %s of type %s for space %s" % (app_name, app_type, space_name)) + app_status = SSHIDE(domain_id, space_name, self.region, is_user_profile=False).get_app_status(app_name, app_type) + result.append(SageMakerStudioApp(domain_id, user_profile_or_space_name=space_name, app_name=app_dict['AppName'], + app_type=app_dict['AppType'], app_status=app_status, is_user_profile=False)) elif app_type in ['JupyterServer', 'KernelGateway']: user_profile_name = app_dict['UserProfileName'] logging.info("Found app %s of type %s for user %s" % (app_name, app_type, user_profile_name)) - app_status = SSHIDE(domain_id, user_profile_name, self.region).get_app_status(app_name, app_type) - result.append(SageMakerStudioApp( - domain_id, user_profile_name, - app_dict['AppName'], app_dict['AppType'], - app_status - )) + app_status = SSHIDE(domain_id, user_profile_name, self.region, is_user_profile=True).get_app_status(app_name, app_type) + result.append(SageMakerStudioApp(domain_id, user_profile_or_space_name=user_profile_name, app_name=app_dict['AppName'], + app_type=app_dict['AppType'], app_status=app_status, is_user_profile=True)) else: logging.info("Unsupported app type %s" % app_type) pass # We don't support other types like 'DetailedProfiler' @@ -290,14 +290,14 @@ def __init__(self, sagemaker: SageMaker, manager: SSMManager, self.manager = manager self.log = log - def list_studio_ide_apps_for_user_and_domain(self, domain_id: Optional[str], user_profile_name: Optional[str]): + def list_studio_ide_apps_for_user_or_space_and_domain(self, domain_id: Optional[str], user_profile_or_space_name: Optional[str]): managed_instances = self.manager.list_all_instances_and_fetch_tags() sagemaker_apps = self.sagemaker.list_ide_apps() result = [] for sagemaker_app in sagemaker_apps: if (sagemaker_app.domain_id == domain_id or domain_id is None or domain_id == "") \ - and (sagemaker_app.user_profile_name == user_profile_name or user_profile_name is None - or user_profile_name == ""): + and (sagemaker_app.user_profile_or_space_name == user_profile_or_space_name or user_profile_or_space_name is None + or user_profile_or_space_name == ""): instance_id = self._find_latest_app_instance_id(managed_instances, sagemaker_app) if instance_id: tags = managed_instances[instance_id] @@ -308,15 +308,15 @@ def list_studio_ide_apps_for_user_and_domain(self, domain_id: Optional[str], use return result def print_studio_ide_apps_for_user_and_domain(self, domain_id: str, user_profile_name: str): - apps: List[SageMakerStudioApp] = self.list_studio_ide_apps_for_user_and_domain(domain_id, user_profile_name) + apps: List[SageMakerStudioApp] = self.list_studio_ide_apps_for_user_or_space_and_domain(domain_id, user_profile_name) for app in apps: print(app) - def list_studio_ide_apps_for_user(self, user_profile_name: str): - return self.list_studio_ide_apps_for_user_and_domain(None, user_profile_name) + def list_studio_ide_apps_for_user_or_space(self, user_profile_or_space_name: str): + return self.list_studio_ide_apps_for_user_or_space_and_domain(None, user_profile_or_space_name) def list_studio_ide_apps(self): - return self.list_studio_ide_apps_for_user_and_domain(None, None) + return self.list_studio_ide_apps_for_user_or_space_and_domain(None, None) @staticmethod def _find_latest_instance_id(managed_instances: Dict[str, Dict[str, str]], @@ -342,7 +342,7 @@ def _find_latest_app_instance_id(managed_instances: Dict[str, Dict[str, str]], s arn = tags['SSHResourceArn'] if 'SSHResourceArn' in tags else '' timestamp = int(tags['SSHTimestamp']) if 'SSHTimestamp' in tags else 0 if (':app/' in arn and arn.endswith(f"/{sagemaker_app.app_name}") - and f"/{sagemaker_app.user_profile_name}/" in arn + and f"/{sagemaker_app.user_profile_or_space_name}/" in arn and f"/{sagemaker_app.domain_id}/" in arn and timestamp > max_timestamp): result = managed_instance_id diff --git a/sagemaker_ssh_helper/log.py b/sagemaker_ssh_helper/log.py index 0c13791..b21d5a1 100644 --- a/sagemaker_ssh_helper/log.py +++ b/sagemaker_ssh_helper/log.py @@ -197,22 +197,26 @@ def get_transform_metadata_url(self, transform_job_name): f"sagemaker/home?region={self.region_name}#" \ f"/transform-jobs/{transform_job_name}" - def get_ide_cloudwatch_url(self, domain, user, app_name): - app_type = 'JupyterServer' if app_name == 'default' else 'KernelGateway' - if user: + def get_ide_cloudwatch_url(self, domain, user_or_space, app_name, is_user_profile=True): + if is_user_profile: + app_type = 'JupyterServer' if app_name == 'default' else 'KernelGateway' + else: + app_type = 'JupyterLab' + if user_or_space: return f"https://{self.aws_console.get_console_domain()}/" \ f"cloudwatch/home?region={self.region_name}#" \ f"logsV2:log-groups/log-group/$252Faws$252Fsagemaker$252Fstudio" \ - f"$3FlogStreamNameFilter$3D{domain}$252F{user}$252F{app_type}$252F{app_name}" + f"$3FlogStreamNameFilter$3D{domain}$252F{user_or_space}$252F{app_type}$252F{app_name}" return f"https://{self.aws_console.get_console_domain()}/" \ f"cloudwatch/home?region={self.region_name}#" \ f"logsV2:log-groups/log-group/$252Faws$252Fsagemaker$252Fstudio" \ f"$3FlogStreamNameFilter$3D{app_type}$252F{app_name}" - def get_ide_metadata_url(self, domain, user): + def get_ide_metadata_url(self, domain, user_or_space, is_user_profile=True): + scope = 'user' if is_user_profile else 'space' return f"https://{self.aws_console.get_console_domain()}/" \ f"sagemaker/home?region={self.region_name}#" \ - f"/studio/{domain}/user/{user}" + f"/studio/{domain}/{scope}/{user_or_space}" def count_sns_notifications(self, topic_name: str, period: timedelta): cloudwatch_resource = boto3.resource('cloudwatch', region_name=self.region_name) diff --git a/sagemaker_ssh_helper/manager.py b/sagemaker_ssh_helper/manager.py index 5ea2aaa..988c3e5 100644 --- a/sagemaker_ssh_helper/manager.py +++ b/sagemaker_ssh_helper/manager.py @@ -113,14 +113,14 @@ def get_transformer_instance_ids(self, transform_job_name, timeout_in_sec=0): self.logger.info(f"Querying SSM instance IDs for transform job {transform_job_name}") return self.get_instance_ids('transform-job', transform_job_name, timeout_in_sec) - def get_studio_user_kgw_instance_ids(self, domain_id, user_profile_name, kgw_name, timeout_in_sec=0, - not_earlier_than_timestamp: int = 0): - self.logger.info(f"Querying SSM instance IDs for SageMaker Studio kernel gateway: '{kgw_name}'") + def get_studio_instance_ids(self, domain_id, user_profile_or_space_name, app_name, timeout_in_sec=0, not_earlier_than_timestamp: int = 0, is_user_profile=False): + self.logger.info(f"Querying SSM instance IDs for app '{app_name}' in SageMaker Studio {'kernel gateway' if is_user_profile else 'space'}: '{user_profile_or_space_name}'") if not domain_id: - arn_filter = f":app/.*/{user_profile_name}/" + arn_filter = f":app/.*/{user_profile_or_space_name}/" else: - arn_filter = f":app/{domain_id}/{user_profile_name}/" - return self.get_instance_ids('app', f"{kgw_name}", timeout_in_sec, + arn_filter = f":app/{domain_id}/{user_profile_or_space_name}/" + + return self.get_instance_ids('app', f"{app_name}", timeout_in_sec, arn_filter_regex=arn_filter, not_earlier_than_timestamp=not_earlier_than_timestamp) diff --git a/sagemaker_ssh_helper/sm-connect-ssh-proxy b/sagemaker_ssh_helper/sm-connect-ssh-proxy index f4cf70f..70742e9 100644 --- a/sagemaker_ssh_helper/sm-connect-ssh-proxy +++ b/sagemaker_ssh_helper/sm-connect-ssh-proxy @@ -94,7 +94,7 @@ send_command=$(aws ssm send-command \ 'cat /etc/ssh/authorized_keys.d/* > /etc/ssh/authorized_keys', 'ls -la /etc/ssh/authorized_keys' ]" \ - --no-cli-pager --no-paginate \ + --no-paginate \ --output json) json_value_regexp='s/^[^"]*".*": \"\(.*\)\"[^"]*/\1/' @@ -114,7 +114,7 @@ for i in $(seq 1 15); do command_output=$(aws ssm get-command-invocation \ --instance-id "${INSTANCE_ID}" \ --command-id "${command_id}" \ - --no-cli-pager --no-paginate \ + --no-paginate \ --output json) command_output=$(echo "$command_output" | $(_python) -m json.tool) command_status=$(echo "$command_output" | grep '"Status":' | sed -e "$json_value_regexp") @@ -166,7 +166,7 @@ proxy_command="aws ssm start-session\ --parameters portNumber=%p" # shellcheck disable=SC2086 -ssh -4 -o User=root -o IdentityFile="${SSH_KEY}" -o IdentitiesOnly=yes \ +ssh -4 -o User=sagemaker-user -o IdentityFile="${SSH_KEY}" -o IdentitiesOnly=yes \ -o ProxyCommand="$proxy_command" \ -o ConnectTimeout=90 \ -o ServerAliveInterval=15 -o ServerAliveCountMax=3 \ diff --git a/sagemaker_ssh_helper/sm-helper-functions b/sagemaker_ssh_helper/sm-helper-functions index 704b22d..e8f8090 100644 --- a/sagemaker_ssh_helper/sm-helper-functions +++ b/sagemaker_ssh_helper/sm-helper-functions @@ -194,9 +194,14 @@ function _print_sm_domain_id() { # shellcheck disable=SC2001 function _print_sm_user_profile_name() { - # FIXME: Check for "SpaceName" - spaces are not supported yet sm_resource_metadata_json=$(tr -d "\n" < /opt/ml/metadata/resource-metadata.json) - echo -n "$sm_resource_metadata_json" | sed -e 's/^.*"UserProfileName":\"\([^"]*\)\".*$/\1/' + echo -n "$sm_resource_metadata_json" | jq -r '.UserProfileName' +} + +# shellcheck disable=SC2001 +function _print_sm_space_name() { + sm_resource_metadata_json=$(tr -d "\n" < /opt/ml/metadata/resource-metadata.json) + echo -n "$sm_resource_metadata_json" | jq -r '.SpaceName' } function _print_sm_studio_python() { @@ -242,7 +247,7 @@ function _start_sshd() { if _is_centos; then /usr/sbin/sshd else - service ssh start || (echo "ERROR: Failed to start sshd service" && exit 255) + sudo service ssh start || (echo "ERROR: Failed to start sshd service" && exit 255) fi } diff --git a/sagemaker_ssh_helper/sm-init-ssm b/sagemaker_ssh_helper/sm-init-ssm index 9d6f4a0..6edc15c 100644 --- a/sagemaker_ssh_helper/sm-init-ssm +++ b/sagemaker_ssh_helper/sm-init-ssm @@ -59,4 +59,8 @@ response=$(aws ssm create-activation \ acode=$(echo $response | jq --raw-output '.ActivationCode') aid=$(echo $response | jq --raw-output '.ActivationId') -echo Yes | amazon-ssm-agent -register -id "$aid" -code "$acode" -region "$CURRENT_REGION" +if [[ -n $(_print_sm_user_profile_name) && $(_print_sm_user_profile_name) != "null" ]]; then + echo Yes | amazon-ssm-agent -register -id "$aid" -code "$acode" -region "$CURRENT_REGION" +else + echo Yes | sudo amazon-ssm-agent -register -id "$aid" -code "$acode" -region "$CURRENT_REGION" +fi diff --git a/sagemaker_ssh_helper/sm-local-ssh-ide b/sagemaker_ssh_helper/sm-local-ssh-ide index b3e6585..1c9e114 100644 --- a/sagemaker_ssh_helper/sm-local-ssh-ide +++ b/sagemaker_ssh_helper/sm-local-ssh-ide @@ -4,11 +4,11 @@ # Run `sm-ssh -h` for the high-level interface. # Commands: -# [--domain-id ] [--user-profile-name ] connect-app [--ssh-only] [] -# [--domain-id ] [--user-profile-name ] proxy-host +# [--domain-id ] [--user-profile-or-space-name ] connect-app [--ssh-only] [] +# [--domain-id ] [--user-profile-or-space-name ] proxy-host # run-command # set-domain-id -# set-user-profile-name +# set-user-profile-or-space-name # set-jb-license-server # SageMaker Studio app name for Kernel Gateways is usually the same as the hostname, @@ -16,7 +16,7 @@ # For JupyterServer app it's 'default'. # Tip: to open SageMaker Studio UI in Firefox from command line on macOS, use the following command: -# open -a Firefox $(AWS_PROFILE=terry aws sagemaker create-presigned-domain-url --domain-id d-lnwlaexample --user-profile-name terry-whitlock --query AuthorizedUrl --output text) +# open -a Firefox $(AWS_PROFILE=terry aws sagemaker create-presigned-domain-url --domain-id d-lnwlaexample --user-profile-or-space-name terry-whitlock --query AuthorizedUrl --output text) # replace with your JetBrains License Server host, or leave it as is if you don't use one @@ -38,7 +38,7 @@ source "$dir"/sm-helper-functions 2>/dev/null || source sm-helper-functions echo "sm-local-ssh-ide: Starting in $dir" DOMAIN_ID="" -USER_PROFILE_NAME="" +USER_PROFILE_OR_SPACE_NAME="" if [[ "$1" == "--domain-id" ]]; then DOMAIN_ID=$2 @@ -46,53 +46,61 @@ if [[ "$1" == "--domain-id" ]]; then shift fi -if [[ "$1" == "--user-profile-name" ]]; then - USER_PROFILE_NAME=$2 +if [[ "$1" == "--user-profile-or-space-name" ]]; then + USER_PROFILE_OR_SPACE_NAME=$2 shift shift fi COMMAND=$1 -if [[ "$COMMAND" != "set-user-profile-name" && "$COMMAND" != "set-domain-id" \ - && "$DOMAIN_ID" == "" && "$USER_PROFILE_NAME" == "" ]]; then +if [[ "$COMMAND" != "set-user-profile-or-space-name" && "$COMMAND" != "set-domain-id" \ + && "$DOMAIN_ID" == "" && "$USER_PROFILE_OR_SPACE_NAME" == "" ]]; then if [ -f ~/.sm-studio-domain-id ]; then DOMAIN_ID="$(cat ~/.sm-studio-domain-id)" else echo "sm-local-ssh-ide: WARNING: SageMaker Studio domain ID is not set."\ "Run 'sm-local-ssh-ide set-domain-id' to override." fi - if [ -f ~/.sm-studio-user-profile-name ]; then - USER_PROFILE_NAME="$(cat ~/.sm-studio-user-profile-name)" + if [ -f ~/.sm-studio-user-profile-or-space-name ]; then + USER_PROFILE_OR_SPACE_NAME="$(cat ~/.sm-studio-user-profile-or-space-name)" else echo "sm-local-ssh-ide: WARNING: SageMaker Studio user profile name is not set."\ - "Run 'sm-local-ssh-ide set-user-profile-name' to override." + "Run 'sm-local-ssh-ide set-user-profile-or-space-name' to override." fi fi if [[ "$COMMAND" == "proxy-host" ]]; then SM_SSH_HOST_NAME="$2" - SM_RESOURCE_TYPE="studio" + SM_RESOURCE_TYPE="space" _check_ssh_proxy_host_name "$SM_SSH_HOST_NAME" "$SM_RESOURCE_TYPE" _export_ssh_key_env_var "$SM_SSH_HOST_NAME" - INSTANCE_ID=$(_generate_key_and_print_instance_id "$SM_SSH_HOST_NAME" "$DOMAIN_ID" "$USER_PROFILE_NAME") + INSTANCE_ID=$(_generate_key_and_print_instance_id "$SM_SSH_HOST_NAME" "$DOMAIN_ID" "$USER_PROFILE_OR_SPACE_NAME") sm-local-start-ssh --proxy-setup-only "${INSTANCE_ID}" -elif [[ "$COMMAND" == "connect-app" ]]; then - - SM_STUDIO_KGW_NAME="$2" +elif [[ "$COMMAND" == connect* ]]; then + SM_STUDIO_APP_NAME="$2" OPTIONS="$3" + if [[ "$COMMAND" == "connect-app" ]]; then + IS_USER_PROFILE="True" + elif [[ "$COMMAND" == "connect-space-app" ]]; then + IS_USER_PROFILE="False" + else + echo "ERROR: Unknown command: '$COMMAND'" + exit 1 + fi + # shellcheck disable=SC2091 INSTANCE_ID=$($(_python) < ~/.sm-studio-domain-id -elif [[ "$COMMAND" == "set-user-profile-name" ]]; then - USER_PROFILE_NAME="$2" - if [[ "$USER_PROFILE_NAME" == "" ]]; then - echo "sm-local-ssh-ide: ERROR: argument is expected" +elif [[ "$COMMAND" == "set-user-profile-or-space-name" ]]; then + USER_PROFILE_OR_SPACE_NAME="$2" + if [[ "$USER_PROFILE_OR_SPACE_NAME" == "" ]]; then + echo "sm-local-ssh-ide: ERROR: argument is expected" exit 1 fi - echo "sm-local-ssh-ide: Saving SageMaker Studio user profile name into ~/.sm-studio-user-profile-name" - echo "$USER_PROFILE_NAME" > ~/.sm-studio-user-profile-name + echo "sm-local-ssh-ide: Saving SageMaker Studio user profile name into ~/.sm-studio-user-profile-or-space-name" + echo "$USER_PROFILE_OR_SPACE_NAME" > ~/.sm-studio-user-profile-or-space-name elif [[ "$COMMAND" == "run-command" ]]; then diff --git a/sagemaker_ssh_helper/sm-ssh-ide b/sagemaker_ssh_helper/sm-ssh-ide index 9b080c1..a45ac76 100644 --- a/sagemaker_ssh_helper/sm-ssh-ide +++ b/sagemaker_ssh_helper/sm-ssh-ide @@ -34,10 +34,10 @@ if [[ "$1" == "configure" ]]; then cat >/etc/profile.d/sm-ssh-ide.sh < str: if fqdn.endswith(".studio.sagemaker") or fqdn == "studio.sagemaker": return "ide" + elif fqdn.endswith(".space.sagemaker") or fqdn == "space.sagemaker": + return "space-ide" elif fqdn.endswith(".notebook.sagemaker") or fqdn == "notebook.sagemaker": return "notebook" elif fqdn.endswith(".training.sagemaker") or fqdn == "training.sagemaker": @@ -38,6 +40,8 @@ def fqdn_to_type(fqdn: str) -> str: def type_to_fqdn(cls, resource_type): if resource_type == "ide": return "studio.sagemaker" + elif resource_type == "space-ide": + return "space.sagemaker" elif resource_type == "notebook": return "notebook.sagemaker" elif resource_type == "training": @@ -68,7 +72,7 @@ def fqdn_to_studio_domain_id(fqdn: str) -> str: return '' @staticmethod - def fqdn_to_studio_user_name(fqdn: str) -> str: + def fqdn_to_studio_user_or_space_name(fqdn: str) -> str: if fqdn.count('.') == 4: return fqdn.split('.')[1] if fqdn.count('.') == 3: @@ -79,18 +83,20 @@ def fqdn_to_studio_user_name(fqdn: str) -> str: @classmethod def _get_arguments(cls, fqdn, resource, command): domain_id = "" - user_profile_name = "" - if resource == "ide": + user_profile_or_space_name = "" + + if resource.endswith("ide"): domain_id = SageMakerSecureShellHelper.fqdn_to_studio_domain_id(fqdn) - user_profile_name = SageMakerSecureShellHelper.fqdn_to_studio_user_name(fqdn) - if domain_id and user_profile_name: - arguments = ["bash", f"sm-local-ssh-{resource}", - "--domain-id", domain_id, "--user-profile-name", user_profile_name] + user_profile_or_space_name = SageMakerSecureShellHelper.fqdn_to_studio_user_or_space_name(fqdn) + if domain_id and user_profile_or_space_name: + arguments = ["bash", "sm-local-ssh-ide", "--domain-id", domain_id, "--user-profile-or-space-name", user_profile_or_space_name] else: arguments = ["bash", f"sm-local-ssh-{resource}"] if resource == "ide": arguments.append(cls._sm_ssh_command_to_local_ssh_command(command, "app")) + elif resource == "space-ide": + arguments.append(cls._sm_ssh_command_to_local_ssh_command(command, "space-app")) elif resource == "notebook": arguments.append(cls._sm_ssh_command_to_local_ssh_command(command, "notebook")) elif resource == "training": @@ -137,10 +143,11 @@ def list(self, fqdn): for resource in self.resources: if resource_type == resource or resource_type == "all": # if-then-else branch for every resource type: - if resource == "ide": + if resource.endswith("ide"): domain_id = SageMakerSecureShellHelper.fqdn_to_studio_domain_id(fqdn) - user_profile_name = SageMakerSecureShellHelper.fqdn_to_studio_user_name(fqdn) - interactive_sagemaker.print_studio_ide_apps_for_user_and_domain(domain_id, user_profile_name) + user_profile_or_space_name = SageMakerSecureShellHelper.fqdn_to_studio_user_or_space_name(fqdn) + interactive_sagemaker.print_studio_ide_apps_for_user_and_domain(domain_id, user_profile_or_space_name) + break elif resource == "notebook": interactive_sagemaker.print_notebook_instances() elif resource == "training": diff --git a/sagemaker_ssh_helper/wrapper.py b/sagemaker_ssh_helper/wrapper.py index a492da7..9c4ffba 100644 --- a/sagemaker_ssh_helper/wrapper.py +++ b/sagemaker_ssh_helper/wrapper.py @@ -183,7 +183,7 @@ def print_ssh_info(self): @classmethod def attach_to_resource(cls, fqdn: str, domain_id: str = '', - user_profile_name: str = '', + user_profile_or_space_name: str = '', sagemaker_session: Session = None): resource_type = SageMakerSecureShellHelper.fqdn_to_type(fqdn) resource_name = SageMakerSecureShellHelper.fqdn_to_name(fqdn) @@ -195,8 +195,8 @@ def attach_to_resource(cls, fqdn: str, return SSHProcessorWrapper.attach(resource_name, sagemaker_session) elif resource_type == 'transform': return SSHTransformerWrapper.attach(resource_name, sagemaker_session) - elif resource_type == 'ide': - return SSHIDEWrapper.attach(domain_id, user_profile_name, resource_name, sagemaker_session) + elif resource_type.endswith('ide'): + return SSHIDEWrapper.attach(domain_id, user_profile_or_space_name, resource_name, sagemaker_session) elif resource_type == 'notebook': return SSHNotebookInstanceWrapper.attach(resource_name, sagemaker_session) else: @@ -643,12 +643,12 @@ def __init__(self, self.not_earlier_than_timestamp = 0 @classmethod - def attach(cls, domain_id, user_profile_name, app_name, sagemaker_session: Session = None, + def attach(cls, domain_id, user_profile_or_space_name, app_name, sagemaker_session: Session = None, not_earlier_than_timestamp: int = 0) -> SSHIDEWrapper: sagemaker_session = sagemaker_session or sagemaker.Session() result = SSHIDEWrapper( '', - SSHIDE(domain_id, user_profile_name, sagemaker_session.boto_region_name), + SSHIDE(domain_id, user_profile_or_space_name, sagemaker_session.boto_region_name), connection_wait_time_seconds=0 ) result.app_name = app_name @@ -656,7 +656,7 @@ def attach(cls, domain_id, user_profile_name, app_name, sagemaker_session: Sessi return result def get_instance_ids(self, retry: int = None, timeout_in_sec: int = 900): - return self.ide.get_kernel_instance_ids(self.app_name, timeout_in_sec=timeout_in_sec, + return self.ide.get_instance_ids(self.app_name, timeout_in_sec=timeout_in_sec, not_earlier_than_timestamp=self.not_earlier_than_timestamp) def get_cloudwatch_url(self): diff --git a/tests/byoi_studio/Dockerfile.codeeditor.internet_free b/tests/byoi_studio/Dockerfile.codeeditor.internet_free new file mode 100644 index 0000000..a885e0f --- /dev/null +++ b/tests/byoi_studio/Dockerfile.codeeditor.internet_free @@ -0,0 +1,35 @@ +FROM public.ecr.aws/sagemaker/sagemaker-distribution:1.6-cpu@sha256:d5148872a9e35b62054fbd82991541592b0ea5edb7b343e579a2daf3b50c2f6b + +USER root +# Install SageMaker SSH Helper for the Internet-free setup +ARG SAGEMAKER_SSH_HELPER_DIR="/opt/sagemaker-ssh-helper" +RUN mkdir -p $SAGEMAKER_SSH_HELPER_DIR + +# See tests/test_ide.py::test_studio_internet_free_mode + +# Log the kernel specs +# The kernel name needs to match SageMaker Image config +# RUN jupyter-kernelspec list + +RUN pip3 uninstall -y -q awscli + +# Install official release (for users): +#RUN \ +# pip3 install --no-cache-dir sagemaker-ssh-helper + +# Install dev release from source (for developers): +COPY ./ $SAGEMAKER_SSH_HELPER_DIR/src/ +RUN \ + pip3 --no-cache-dir install wheel && \ + pip3 --no-cache-dir install $SAGEMAKER_SSH_HELPER_DIR/src/ + +# Pre-configure the container with packages, which should be installed from Internet +# Consider adding `--ssh-only` flag and commenting the first RUN command, if you don't plan to connect +# to the VNC server or to the Jupyter notebook +# RUN apt-get update -y && apt-get upgrade -y + +RUN sm-ssh-ide configure --ssh-only + +USER $MAMBA_USER +WORKDIR "/home/${NB_USER}" +ENTRYPOINT ["entrypoint-code-editor"] \ No newline at end of file diff --git a/tests/test_cli.py b/tests/test_cli.py index 5bfe289..f74fbc3 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -30,18 +30,18 @@ def test_fqdn_to_name(): def test_fqdn_to_studio_user_and_domain(): sm_ssh = SageMakerSecureShellHelper() assert sm_ssh.fqdn_to_studio_domain_id("ssh-training-job.training.sagemaker") == "" - assert sm_ssh.fqdn_to_studio_user_name("ssh-training-job.training.sagemaker") == "" + assert sm_ssh.fqdn_to_studio_user_or_space_name("ssh-training-job.training.sagemaker") == "" assert sm_ssh.fqdn_to_studio_domain_id("ssh-test-ds2-cpu.studio.sagemaker") == "" - assert sm_ssh.fqdn_to_studio_user_name("ssh-test-ds2-cpu.studio.sagemaker") == "" + assert sm_ssh.fqdn_to_studio_user_or_space_name("ssh-test-ds2-cpu.studio.sagemaker") == "" assert sm_ssh.fqdn_to_studio_domain_id( "ssh-test-ds2-cpu.test-data-science.d-egm0dexample.studio.sagemaker" ) == "d-egm0dexample" - assert sm_ssh.fqdn_to_studio_user_name( + assert sm_ssh.fqdn_to_studio_user_or_space_name( "ssh-test-ds2-cpu.test-data-science.d-egm0dexample.studio.sagemaker" ) == "test-data-science" assert sm_ssh.fqdn_to_studio_domain_id( "test-data-science.d-egm0dexample.studio.sagemaker" ) == "d-egm0dexample" - assert sm_ssh.fqdn_to_studio_user_name( + assert sm_ssh.fqdn_to_studio_user_or_space_name( "test-data-science.d-egm0dexample.studio.sagemaker" ) == "test-data-science" diff --git a/tests/test_ide.py b/tests/test_ide.py index 625a1c8..0325386 100644 --- a/tests/test_ide.py +++ b/tests/test_ide.py @@ -323,7 +323,7 @@ def test_studio_default_domain_multiple_users(request): @pytest.mark.parametrize('user_profile_name', ['test-firefox']) def test_studio_notebook_in_firefox(request, user_profile_name): - ide = SSHIDE(request.config.getini('sagemaker_studio_domain'), user_profile_name) + ide = SSHIDE(request.config.getini('sagemaker_studio_domain'), user_profile_name, is_user_profile=True) local_user_id = os.environ['LOCAL_USER_ID'] jb_server_host = os.environ['JB_LICENSE_SERVER_HOST'] @@ -371,7 +371,7 @@ def test_studio_notebook_in_firefox(request, user_profile_name): browser_automation.restart_kernel_and_run_all_cells() data_science_kernel = "sagemaker-data-science-ml-m5-large-6590da95dc67eec021b14bedc036" # noqa - studio_id = ide.get_kernel_instance_id( + studio_id = ide.get_instance_id( data_science_kernel, timeout_in_sec=300, not_earlier_than_timestamp=current_time_stamp diff --git a/tests/test_ssm_manager.py b/tests/test_ssm_manager.py index 0766925..84d8ef9 100644 --- a/tests/test_ssm_manager.py +++ b/tests/test_ssm_manager.py @@ -199,9 +199,10 @@ def test_can_filter_by_domain_and_user(): }, }) - ids = manager.get_studio_user_kgw_instance_ids( + ids = manager.get_studio_instance_ids( "d-0123456789bc", "default-1111111111111", - "sagemaker-data-science-ml-m5-large-1234567890abcdef0" + "sagemaker-data-science-ml-m5-large-1234567890abcdef0", + is_user_profile=True ) assert len(ids) == 1 assert ids[0] == "mi-01234567890abcd08" @@ -234,9 +235,10 @@ def test_can_filter_by_user_with_latest_domain(): }, }) - ids = manager.get_studio_user_kgw_instance_ids( + ids = manager.get_studio_instance_ids( "", "default-1111111111111", - "sagemaker-data-science-ml-m5-large-1234567890abcdef0" + "sagemaker-data-science-ml-m5-large-1234567890abcdef0", + is_user_profile=True ) assert len(ids) == 2 assert ids[0] == "mi-01234567890abcd08" @@ -310,38 +312,24 @@ def test_can_list_ssh_and_non_ssh_instances(): sagemaker = SageMaker('eu-west-1') sagemaker.list_ide_apps = Mock(return_value=[ - SageMakerStudioApp( - "d-0123456789bc", "janedoe", "default", "JupyterServer", IDEAppStatus("InService") - ), - SageMakerStudioApp( - "d-0123456789bc", "janedoe", "ssh-test-kgw", "KernelGateway", IDEAppStatus("InService") - ), - SageMakerStudioApp( - "d-0123456789bc", "janedoe", "data-science-m5-no-ssh", "KernelGateway", IDEAppStatus("InService") - ), - SageMakerStudioApp( - "d-0123456789bc", "janedoe", "data-science-g4-no-ssh", "KernelGateway", IDEAppStatus("Offline") - ), - SageMakerStudioApp( - "d-0123456789bc", "terry", "sagemaker-data-science-ml-m5-large-1234567890abcdef0", "KernelGateway", IDEAppStatus("Offline") - ), - SageMakerStudioApp( - "d-0123456789bc", "terry", "data-science-m5-no-ssh", "KernelGateway", IDEAppStatus("Offline") - ), - - SageMakerStudioApp( - "d-0123456789ab", "terry", "sagemaker-data-science-ml-m5-large-1234567890abcdef0", "KernelGateway", IDEAppStatus("Offline") - ), - SageMakerStudioApp( - "d-0123456789ab", "terry", "data-science-m5-no-ssh", "KernelGateway", IDEAppStatus("Offline") - ), + SageMakerStudioApp("d-0123456789bc", "janedoe", "default", "JupyterServer", IDEAppStatus("InService")), + SageMakerStudioApp("d-0123456789bc", "janedoe", "ssh-test-kgw", "KernelGateway", IDEAppStatus("InService")), + SageMakerStudioApp("d-0123456789bc", "janedoe", "data-science-m5-no-ssh", "KernelGateway", IDEAppStatus("InService")), + SageMakerStudioApp("d-0123456789bc", "janedoe", "data-science-g4-no-ssh", "KernelGateway", IDEAppStatus("Offline")), + SageMakerStudioApp("d-0123456789bc", "terry", "sagemaker-data-science-ml-m5-large-1234567890abcdef0", "KernelGateway", + IDEAppStatus("Offline")), + SageMakerStudioApp("d-0123456789bc", "terry", "data-science-m5-no-ssh", "KernelGateway", IDEAppStatus("Offline")), + + SageMakerStudioApp("d-0123456789ab", "terry", "sagemaker-data-science-ml-m5-large-1234567890abcdef0", "KernelGateway", + IDEAppStatus("Offline")), + SageMakerStudioApp("d-0123456789ab", "terry", "data-science-m5-no-ssh", "KernelGateway", IDEAppStatus("Offline")), # LocalApp("janedoe", "AIDACKCEVSQ6C2EXAMPLE:janedoe@SSO", "macOS 13.5.1"), # LocalApp("terry", "AIDACKCEVSQ6C2EXAMPLE:terry@SSO", "Windows 10 Pro"), ]) interactive_sagemaker = InteractiveSageMaker(sagemaker, manager) - apps = interactive_sagemaker.list_studio_ide_apps_for_user_and_domain( + apps = interactive_sagemaker.list_studio_ide_apps_for_user_or_space_and_domain( "d-0123456789bc", "janedoe", ) assert len(apps) == 4 @@ -350,12 +338,12 @@ def test_can_list_ssh_and_non_ssh_instances(): assert apps[0].ssm_instance_id == "mi-01234567890abcd05" assert apps[0].ssh_owner == "AIDACKCEVSQ6C2EXAMPLE:janedoe@SSO" - apps = interactive_sagemaker.list_studio_ide_apps_for_user_and_domain( + apps = interactive_sagemaker.list_studio_ide_apps_for_user_or_space_and_domain( "d-0123456789bc", "terry", ) assert len(apps) == 2 - apps = interactive_sagemaker.list_studio_ide_apps_for_user( + apps = interactive_sagemaker.list_studio_ide_apps_for_user_or_space( "terry", ) assert len(apps) == 4