Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
106 changes: 61 additions & 45 deletions sagemaker_ssh_helper/ide.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,12 +46,13 @@ def __init__(self, arn, version_arn) -> None:
class SSHIDE:
logger = logging.getLogger('sagemaker-ssh-helper:SSHIDE')

def __init__(self, domain_id: str, user_or_space: str = None, region_name: str = None, is_user_profile: bool = True):
    """
    Bind the helper to one SageMaker Studio domain and one app owner.

    :param domain_id: ID of the SageMaker Studio domain
    :param user_or_space: user profile name, or space name when ``is_user_profile`` is False
    :param region_name: AWS region; falls back to the region of the default boto3 session
    :param is_user_profile: True when ``user_or_space`` names a user profile, False when it names a space
    """
    self.domain_id = domain_id
    self.user_or_space = user_or_space
    self.is_user_profile = is_user_profile
    # Resolve the region once so that the API client and the log helper agree on it.
    self.current_region = region_name or boto3.session.Session().region_name
    self.client = boto3.client('sagemaker', region_name=self.current_region)
    self.ssh_log = SSHLog(region_name=self.current_region)

def create_ssh_kernel_app(self, app_name: str,
image_name_or_arn='sagemaker-datascience-38',
Expand Down Expand Up @@ -108,13 +109,18 @@ def get_app_status(self, app_name: str, app_type: str = 'KernelGateway') -> IDEA
:return: None | 'InService' | 'Deleted' | 'Deleting' | 'Failed' | 'Pending'
"""
response = None

describe_app_request_params = {
"DomainId": self.domain_id,
"AppType": app_type,
"AppName": app_name,
}

describe_app_request_params.update(
{"UserProfileName": self.user_or_space} if self.is_user_profile else {"SpaceName": self.user_or_space})

try:
response = self.client.describe_app(
DomainId=self.domain_id,
AppType=app_type,
UserProfileName=self.user,
AppName=app_name,
)
response = self.client.describe_app(**describe_app_request_params)
except ClientError as e:
error_code = e.response.get("Error", {}).get("Code")
if error_code == 'ResourceNotFound':
Expand All @@ -137,12 +143,16 @@ def delete_app(self, app_name, app_type, wait: bool = True):
self.logger.info(f"Deleting app {app_name}")

try:
_ = self.client.delete_app(
DomainId=self.domain_id,
AppType=app_type,
UserProfileName=self.user,
AppName=app_name,
)
delete_app_request_params = {
"DomainId": self.domain_id,
"AppType": app_type,
"AppName": app_name,
}

delete_app_request_params.update(
{"UserProfileName": self.user_or_space} if self.is_user_profile else {"SpaceName": self.user_or_space})

_ = self.client.delete_app(**delete_app_request_params)
except ClientError as e:
# probably, already deleted
code = e.response.get("Error", {}).get("Code")
Expand Down Expand Up @@ -173,13 +183,17 @@ def create_app(self, app_name, app_type, instance_type, image_arn,
if lifecycle_arn:
resource_spec['LifecycleConfigArn'] = lifecycle_arn

_ = self.client.create_app(
DomainId=self.domain_id,
AppType=app_type,
AppName=app_name,
UserProfileName=self.user,
ResourceSpec=resource_spec,
)
create_app_request_params = {
"DomainId": self.domain_id,
"AppType": app_type,
"AppName": app_name,
"ResourceSpec": resource_spec,
}

create_app_request_params.update(
{"UserProfileName": self.user_or_space} if self.is_user_profile else {"SpaceName": self.user_or_space})

_ = self.client.create_app(**create_app_request_params)
status = self.get_app_status(app_name)
while status.is_pending():
self.logger.info(f"Waiting for the InService status. Current status: {status}")
Expand All @@ -199,45 +213,47 @@ def resolve_sagemaker_kernel_image_arn(self, image_name):
sagemaker_account_id = "470317259841" # eu-west-1, TODO: check all images
return f"arn:aws:sagemaker:{self.current_region}:{sagemaker_account_id}:image/{image_name}"

def print_instance_id(self, app_name, timeout_in_sec, index: int = 0):
    """
    Print the instance ID that ``get_instance_id`` returns for the given app.

    :param app_name: name of the Studio app
    :param timeout_in_sec: passed through to ``get_instance_id``
    :param index: position of the wanted ID in the list of resolved IDs
    """
    print(self.get_instance_id(app_name, timeout_in_sec, index))


# Backward-compatible alias for the pre-rename public name.
print_kernel_instance_id = print_instance_id

def get_instance_id(self, app_name, timeout_in_sec, index: int = 0,
                    not_earlier_than_timestamp: int = 0):
    """
    Return one instance ID out of those resolved by ``get_instance_ids``.

    :param app_name: name of the Studio app
    :param timeout_in_sec: passed through to ``get_instance_ids``
    :param index: position of the wanted ID in the resolved list
    :param not_earlier_than_timestamp: passed through to ``get_instance_ids``
    :return: the ID found at position ``index``
    :raises ValueError: if no instance IDs were resolved for the app
    """
    ids = self.get_instance_ids(app_name, timeout_in_sec, not_earlier_than_timestamp)
    if not ids:
        raise ValueError(f"No instances found for app {app_name}")
    return ids[index]


# Backward-compatible alias for the pre-rename public name.
get_kernel_instance_id = get_instance_id

def get_instance_ids(self, app_name: str, timeout_in_sec: int, not_earlier_than_timestamp: int = 0):
    """
    Resolve the instance IDs registered for a Studio app through SSMManager.

    When the user profile / space name is known, the lookup is scoped to it (and to the
    domain, if the domain ID is set). Otherwise the lookup falls back to the app name only.

    :param app_name: name of the Studio app
    :param timeout_in_sec: passed through to the SSMManager lookup
    :param not_earlier_than_timestamp: passed through to the SSMManager lookup
    :return: whatever the selected SSMManager lookup returns
    """
    scope = 'user' if self.is_user_profile else 'space'
    self.logger.info(f"Resolving IDE instance IDs for app '{app_name}' through SSM tags in domain '{self.domain_id}' "
                     f"for {scope} '{self.user_or_space}'")
    self.log_urls(app_name)

    if not self.user_or_space:
        # Nothing to scope the search by except the app name.
        self.logger.warning(
            f"Domain ID or {scope} are not set. Will attempt to connect to the latest "
            f"active {app_name} in the region {self.current_region}")
        return SSMManager().get_studio_kgw_instance_ids(app_name, timeout_in_sec, not_earlier_than_timestamp)

    if not self.domain_id:
        self.logger.warning(f"Domain ID is not set. Will attempt to connect to the latest "
                            f"active {app_name} in the region {self.current_region} "
                            f"for {scope} {self.user_or_space}")

    # An empty domain ID makes the manager match the owner name in any domain.
    return SSMManager().get_studio_instance_ids(self.domain_id or "", self.user_or_space, app_name,
                                                timeout_in_sec, not_earlier_than_timestamp,
                                                is_user_profile=self.is_user_profile)


# Backward-compatible alias for the pre-rename public name.
get_kernel_instance_ids = get_instance_ids

def log_urls(self, app_name):
    """
    Log the CloudWatch URL of the app and, when available, the Studio metadata URL.

    The metadata URL embeds both the domain ID and the user profile / space name, so it is
    logged only when both are set; otherwise the link would end in ``/None``.

    :param app_name: name of the Studio app
    """
    self.logger.info(f"Remote logs are at {self.get_cloudwatch_url(app_name)}")
    if self.domain_id and self.user_or_space:
        self.logger.info(f"Remote apps metadata is at {self.get_user_or_space_metadata_url()}")

def get_cloudwatch_url(self, app_name):
    """
    Return the CloudWatch URL for the app, as built by the log helper.

    :param app_name: name of the Studio app
    :return: result of ``ssh_log.get_ide_cloudwatch_url`` for this domain and owner
    """
    return self.ssh_log.get_ide_cloudwatch_url(self.domain_id, self.user_or_space, app_name, self.is_user_profile)

def get_user_or_space_metadata_url(self):
    """
    Return the Studio metadata URL for the owner, as built by the log helper.

    :return: result of ``ssh_log.get_ide_metadata_url`` for this domain and owner
    """
    return self.ssh_log.get_ide_metadata_url(self.domain_id, self.user_or_space, self.is_user_profile)


# Backward-compatible alias for the pre-rename public name.
get_user_metadata_url = get_user_or_space_metadata_url

def create_and_attach_image(self, image_name, ecr_image_name,
role_arn,
Expand Down
42 changes: 21 additions & 21 deletions sagemaker_ssh_helper/interactive_sagemaker.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,23 +31,23 @@ def set_ping_status(self, ping_status):


class SageMakerStudioApp(SageMakerCoreApp):
def __init__(self, domain_id: str, user_profile_or_space_name: str, app_name: str, app_type: str, app_status: IDEAppStatus,
             is_user_profile: bool = True) -> None:
    """
    Describe one SageMaker Studio app that belongs to a user profile or to a space.

    :param domain_id: ID of the Studio domain
    :param user_profile_or_space_name: name of the owning user profile or space
    :param app_name: name of the app
    :param app_type: type of the app
    :param app_status: status of the app
    :param is_user_profile: True when the owner is a user profile, False when it is a space
    """
    super().__init__()
    self.domain_id = domain_id
    self.user_profile_or_space_name = user_profile_or_space_name
    self.app_name = app_name
    self.app_type = app_type
    self.app_status = app_status
    # The resource type is what distinguishes profile-owned apps from space-owned ones.
    self.resource_type = "ide" if is_user_profile else "space-ide"

def __str__(self) -> str:
    """
    Format the app as one table row.

    The row holds three padded columns (ping status or the no-SSH flag, app type, app status)
    followed by ``<app_name>.<owner>.<domain_id>.<fqdn of the resource type>``.
    """
    return "{0:<16} {1:<18} {2:<12} {5}.{4}.{3}.{6}".format(
        # The ping status is shown only when an SSM instance ID is set.
        self.ping_status if self.ssm_instance_id else self.NO_SSH_FLAG,
        self.app_type,
        str(self.app_status),
        self.domain_id,
        self.user_profile_or_space_name,
        self.app_name,
        SageMakerSecureShellHelper.type_to_fqdn(self.resource_type)
    )
Expand Down Expand Up @@ -160,17 +160,17 @@ def list_ide_apps(self) -> List[SageMakerStudioApp]:
app_name = app_dict['AppName']
app_type = app_dict['AppType']
if 'SpaceName' in app_dict:
logging.info("Don't support spaces: skipping app %s of type %s" % (app_name, app_type))
pass
space_name = app_dict['SpaceName']
logging.info("Found app %s of type %s for space %s" % (app_name, app_type, space_name))
app_status = SSHIDE(domain_id, space_name, self.region, is_user_profile=False).get_app_status(app_name, app_type)
result.append(SageMakerStudioApp(domain_id, user_profile_or_space_name=space_name, app_name=app_dict['AppName'],
app_type=app_dict['AppType'], app_status=app_status, is_user_profile=False))
elif app_type in ['JupyterServer', 'KernelGateway']:
user_profile_name = app_dict['UserProfileName']
logging.info("Found app %s of type %s for user %s" % (app_name, app_type, user_profile_name))
app_status = SSHIDE(domain_id, user_profile_name, self.region).get_app_status(app_name, app_type)
result.append(SageMakerStudioApp(
domain_id, user_profile_name,
app_dict['AppName'], app_dict['AppType'],
app_status
))
app_status = SSHIDE(domain_id, user_profile_name, self.region, is_user_profile=True).get_app_status(app_name, app_type)
result.append(SageMakerStudioApp(domain_id, user_profile_or_space_name=user_profile_name, app_name=app_dict['AppName'],
app_type=app_dict['AppType'], app_status=app_status, is_user_profile=True))
else:
logging.info("Unsupported app type %s" % app_type)
pass # We don't support other types like 'DetailedProfiler'
Expand Down Expand Up @@ -290,14 +290,14 @@ def __init__(self, sagemaker: SageMaker, manager: SSMManager,
self.manager = manager
self.log = log

def list_studio_ide_apps_for_user_and_domain(self, domain_id: Optional[str], user_profile_name: Optional[str]):
def list_studio_ide_apps_for_user_or_space_and_domain(self, domain_id: Optional[str], user_profile_or_space_name: Optional[str]):
managed_instances = self.manager.list_all_instances_and_fetch_tags()
sagemaker_apps = self.sagemaker.list_ide_apps()
result = []
for sagemaker_app in sagemaker_apps:
if (sagemaker_app.domain_id == domain_id or domain_id is None or domain_id == "") \
and (sagemaker_app.user_profile_name == user_profile_name or user_profile_name is None
or user_profile_name == ""):
and (sagemaker_app.user_profile_or_space_name == user_profile_or_space_name or user_profile_or_space_name is None
or user_profile_or_space_name == ""):
instance_id = self._find_latest_app_instance_id(managed_instances, sagemaker_app)
if instance_id:
tags = managed_instances[instance_id]
Expand All @@ -308,15 +308,15 @@ def list_studio_ide_apps_for_user_and_domain(self, domain_id: Optional[str], use
return result

def print_studio_ide_apps_for_user_and_domain(self, domain_id: str, user_profile_name: str):
    """
    Print every app returned by ``list_studio_ide_apps_for_user_or_space_and_domain``, one per line.

    :param domain_id: domain ID passed to the listing call
    :param user_profile_name: owner name passed to the listing call
    """
    for app in self.list_studio_ide_apps_for_user_or_space_and_domain(domain_id, user_profile_name):
        print(app)

def list_studio_ide_apps_for_user_or_space(self, user_profile_or_space_name: str):
    """
    List the apps of one user profile or space, with no domain given.

    :param user_profile_or_space_name: owner name passed to the listing call
    :return: result of ``list_studio_ide_apps_for_user_or_space_and_domain`` with ``domain_id=None``
    """
    return self.list_studio_ide_apps_for_user_or_space_and_domain(None, user_profile_or_space_name)


# Backward-compatible alias for the pre-rename public name.
list_studio_ide_apps_for_user = list_studio_ide_apps_for_user_or_space

def list_studio_ide_apps(self):
    """
    List the apps with neither a domain nor an owner given.

    :return: result of ``list_studio_ide_apps_for_user_or_space_and_domain(None, None)``
    """
    return self.list_studio_ide_apps_for_user_or_space_and_domain(None, None)

@staticmethod
def _find_latest_instance_id(managed_instances: Dict[str, Dict[str, str]],
Expand All @@ -342,7 +342,7 @@ def _find_latest_app_instance_id(managed_instances: Dict[str, Dict[str, str]], s
arn = tags['SSHResourceArn'] if 'SSHResourceArn' in tags else ''
timestamp = int(tags['SSHTimestamp']) if 'SSHTimestamp' in tags else 0
if (':app/' in arn and arn.endswith(f"/{sagemaker_app.app_name}")
and f"/{sagemaker_app.user_profile_name}/" in arn
and f"/{sagemaker_app.user_profile_or_space_name}/" in arn
and f"/{sagemaker_app.domain_id}/" in arn
and timestamp > max_timestamp):
result = managed_instance_id
Expand Down
16 changes: 10 additions & 6 deletions sagemaker_ssh_helper/log.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,22 +197,26 @@ def get_transform_metadata_url(self, transform_job_name):
f"sagemaker/home?region={self.region_name}#" \
f"/transform-jobs/{transform_job_name}"

def get_ide_cloudwatch_url(self, domain, user_or_space, app_name, is_user_profile=True):
    """
    Build the console URL of the Studio log group, pre-filtered by log stream name.

    :param domain: Studio domain ID; used in the filter only when ``user_or_space`` is set
    :param user_or_space: user profile or space name; when empty, the filter is app type and name only
    :param app_name: name of the app
    :param is_user_profile: True when ``user_or_space`` names a user profile, False when it names a space
    :return: the URL as a string
    """
    if not is_user_profile:
        app_type = 'JupyterLab'
    elif app_name == 'default':
        app_type = 'JupyterServer'
    else:
        app_type = 'KernelGateway'

    # '$252F' is the separator between the parts of the stream name filter.
    stream_filter = f"{app_type}$252F{app_name}"
    if user_or_space:
        stream_filter = f"{domain}$252F{user_or_space}$252F{stream_filter}"

    return f"https://{self.aws_console.get_console_domain()}/" \
           f"cloudwatch/home?region={self.region_name}#" \
           f"logsV2:log-groups/log-group/$252Faws$252Fsagemaker$252Fstudio" \
           f"$3FlogStreamNameFilter$3D{stream_filter}"

def get_ide_metadata_url(self, domain, user_or_space, is_user_profile=True):
    """
    Build the console URL of the Studio page for a user profile or a space.

    :param domain: Studio domain ID
    :param user_or_space: user profile or space name
    :param is_user_profile: True selects the ``user`` path segment, False selects ``space``
    :return: the URL as a string
    """
    scope = 'user' if is_user_profile else 'space'
    return f"https://{self.aws_console.get_console_domain()}/" \
           f"sagemaker/home?region={self.region_name}#" \
           f"/studio/{domain}/{scope}/{user_or_space}"

def count_sns_notifications(self, topic_name: str, period: timedelta):
cloudwatch_resource = boto3.resource('cloudwatch', region_name=self.region_name)
Expand Down
12 changes: 6 additions & 6 deletions sagemaker_ssh_helper/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,14 +113,14 @@ def get_transformer_instance_ids(self, transform_job_name, timeout_in_sec=0):
self.logger.info(f"Querying SSM instance IDs for transform job {transform_job_name}")
return self.get_instance_ids('transform-job', transform_job_name, timeout_in_sec)

def get_studio_instance_ids(self, domain_id, user_profile_or_space_name, app_name, timeout_in_sec=0,
                            not_earlier_than_timestamp: int = 0, is_user_profile=True):
    """
    Query the instance IDs of a Studio app that belongs to a user profile or a space.

    :param domain_id: Studio domain ID; when empty, any domain is matched
    :param user_profile_or_space_name: name of the owning user profile or space
    :param app_name: name of the app
    :param timeout_in_sec: passed through to ``get_instance_ids``
    :param not_earlier_than_timestamp: passed through to ``get_instance_ids``
    :param is_user_profile: only selects the wording of the log message; defaults to True
        to match the pre-rename method, which handled user profiles only
    :return: result of ``get_instance_ids`` for resource type ``'app'``
    """
    scope = 'user profile' if is_user_profile else 'space'
    self.logger.info(f"Querying SSM instance IDs for app '{app_name}' "
                     f"in SageMaker Studio {scope}: '{user_profile_or_space_name}'")

    # Without a domain ID, let the regex match any domain segment.
    domain_pattern = domain_id if domain_id else ".*"
    arn_filter = f":app/{domain_pattern}/{user_profile_or_space_name}/"

    return self.get_instance_ids('app', f"{app_name}", timeout_in_sec,
                                 arn_filter_regex=arn_filter,
                                 not_earlier_than_timestamp=not_earlier_than_timestamp)


# Backward-compatible alias for the pre-rename public name.
get_studio_user_kgw_instance_ids = get_studio_instance_ids

Expand Down
6 changes: 3 additions & 3 deletions sagemaker_ssh_helper/sm-connect-ssh-proxy
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ send_command=$(aws ssm send-command \
'cat /etc/ssh/authorized_keys.d/* > /etc/ssh/authorized_keys',
'ls -la /etc/ssh/authorized_keys'
]" \
--no-cli-pager --no-paginate \
--no-paginate \
--output json)

json_value_regexp='s/^[^"]*".*": \"\(.*\)\"[^"]*/\1/'
Expand All @@ -114,7 +114,7 @@ for i in $(seq 1 15); do
command_output=$(aws ssm get-command-invocation \
--instance-id "${INSTANCE_ID}" \
--command-id "${command_id}" \
--no-cli-pager --no-paginate \
--no-paginate \
--output json)
command_output=$(echo "$command_output" | $(_python) -m json.tool)
command_status=$(echo "$command_output" | grep '"Status":' | sed -e "$json_value_regexp")
Expand Down Expand Up @@ -166,7 +166,7 @@ proxy_command="aws ssm start-session\
--parameters portNumber=%p"

# shellcheck disable=SC2086
ssh -4 -o User=root -o IdentityFile="${SSH_KEY}" -o IdentitiesOnly=yes \
ssh -4 -o User=sagemaker-user -o IdentityFile="${SSH_KEY}" -o IdentitiesOnly=yes \
-o ProxyCommand="$proxy_command" \
-o ConnectTimeout=90 \
-o ServerAliveInterval=15 -o ServerAliveCountMax=3 \
Expand Down
Loading