diff --git a/src/sagemaker/estimator.py b/src/sagemaker/estimator.py index 5d6f25dbbb..38fd60d4e3 100644 --- a/src/sagemaker/estimator.py +++ b/src/sagemaker/estimator.py @@ -65,6 +65,8 @@ validate_source_code_input_against_pipeline_variables, ) from sagemaker.inputs import TrainingInput, FileSystemInput +from sagemaker.interactive_apps import SupportedInteractiveAppTypes +from sagemaker.interactive_apps.tensorboard import TensorBoardApp from sagemaker.instance_group import InstanceGroup from sagemaker.utils import instance_supports_kms from sagemaker.job import _Job @@ -750,6 +752,8 @@ def __init__( # Internal flag self._is_output_path_set_from_default_bucket_and_prefix = False + self.tensorboard_app = TensorBoardApp(region=self.sagemaker_session.boto_region_name) + @abstractmethod def training_image_uri(self): """Return the Docker image to use for training. @@ -2256,6 +2260,73 @@ def update_profiler( _TrainingJob.update(self, profiler_rule_configs, profiler_config_request_dict) + def get_app_url( + self, + app_type, + open_in_default_web_browser=True, + create_presigned_domain_url=False, + domain_id=None, + user_profile_name=None, + optional_create_presigned_url_kwargs=None, + ): + """Generate a URL to help access the specified app hosted in Amazon SageMaker Studio. + + Args: + app_type (str or SupportedInteractiveAppTypes): Required. The app type available in + SageMaker Studio to return a URL to. + open_in_default_web_browser (bool): Optional. When True, the URL will attempt to be + opened in the environment's default web browser. Otherwise, the resulting URL will + be returned by this function. + Default: ``True`` + create_presigned_domain_url (bool): Optional. Determines whether a presigned domain URL + should be generated instead of an unsigned URL. This only applies when called from + outside of a SageMaker Studio environment. If this is set to True inside of a + SageMaker Studio environment, it will be ignored. + Default: ``False`` + domain_id (str): Optional. The AWS Studio domain that the resulting app will use. If + code is executing in a Studio environment and this was not supplied, this will be + automatically detected. If not supplied and running in a non-Studio environment, it + is up to the derived class on how to handle that, but in general, a redirect to a + landing page can be expected. + Default: ``None`` + user_profile_name (str): Optional. The AWS Studio user profile that the resulting app + will use. If code is executing in a Studio environment and this was not supplied, + this will be automatically detected. If not supplied and running in a + non-Studio environment, it is up to the derived class on how to handle that, but in + general, a redirect to a landing page can be expected. + Default: ``None`` + optional_create_presigned_url_kwargs (dict): Optional. This parameter + should be passed when a user outside of Studio wants a presigned URL to the + TensorBoard application and wants to modify the optional parameters of the + create_presigned_domain_url call. + Default: ``None`` + Returns: + str: A URL for the requested app in SageMaker Studio. + """ + url = None + + # Get app_type in lower str format + if isinstance(app_type, SupportedInteractiveAppTypes): + app_type = app_type.name + app_type = app_type.lower() + + if app_type == SupportedInteractiveAppTypes.TENSORBOARD.name.lower(): + training_job_name = None + if self._current_job_name: + training_job_name = self._current_job_name + url = self.tensorboard_app.get_app_url( + training_job_name=training_job_name, + open_in_default_web_browser=open_in_default_web_browser, + create_presigned_domain_url=create_presigned_domain_url, + domain_id=domain_id, + user_profile_name=user_profile_name, + optional_create_presigned_url_kwargs=optional_create_presigned_url_kwargs, + ) + else: + raise ValueError(f"{app_type} does not support URL retrieval.") + + return url + class _TrainingJob(_Job): """Placeholder docstring""" diff --git a/src/sagemaker/interactive_apps/__init__.py b/src/sagemaker/interactive_apps/__init__.py index 97acb3a777..702d4c7d90 100644 --- a/src/sagemaker/interactive_apps/__init__.py +++ b/src/sagemaker/interactive_apps/__init__.py @@ -10,12 +10,21 @@ # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. -"""Classes for using debugger and profiler with Amazon SageMaker.""" +"""Classes for starting/accessing apps hosted on Amazon SageMaker Studio.""" + from __future__ import absolute_import +from enum import Enum + from sagemaker.interactive_apps.tensorboard import ( # noqa: F401 TensorBoardApp, ) from sagemaker.interactive_apps.detail_profiler_app import ( # noqa: F401 DetailProfilerApp, ) + + +class SupportedInteractiveAppTypes(Enum): + """SupportedInteractiveAppTypes indicates which apps are supported.""" + + TENSORBOARD = 1 diff --git a/src/sagemaker/interactive_apps/base_interactive_app.py b/src/sagemaker/interactive_apps/base_interactive_app.py new file mode 100644 index 0000000000..2004df54c9 --- /dev/null +++ b/src/sagemaker/interactive_apps/base_interactive_app.py @@ -0,0 +1,219 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""A base class for starting/accessing apps hosted on Amazon SageMaker Studio""" + +from __future__ import absolute_import + +import abc +import base64 +import json +import logging +import os +import re +import webbrowser + +from typing import Optional +import boto3 +from sagemaker.session import Session, NOTEBOOK_METADATA_FILE + +logger = logging.getLogger(__name__) + + +class BaseInteractiveApp(abc.ABC): + """BaseInteractiveApp is a base class for creating/accessing apps hosted on SageMaker.""" + + def __init__( + self, + region: Optional[str] = None, + ): + """Initialize a BaseInteractiveApp object. + + Args: + region (str): Optional. The AWS Region, e.g. us-east-1. If not specified, + one is created using the default AWS configuration chain. + Default: ``None`` + """ + if isinstance(region, str): + self.region = region + else: + try: + self.region = Session().boto_region_name + except ValueError: + raise ValueError( + "Failed to get the Region information from the default config. Please either " + "pass your Region manually as an input argument or set up the local AWS" + " configuration." + ) + + self._sagemaker_client = boto3.client("sagemaker", region_name=self.region) + # Used to store domain and user profile info retrieved from Studio environment. + self._domain_id = None + self._user_profile_name = None + self._get_domain_and_user() + + def __str__(self): + """Return str(self).""" + return f"{type(self).__name__}(region={self.region})" + + def __repr__(self): + """Return repr(self).""" + return self.__str__() + + def _get_domain_and_user(self): + """Get and validate studio domain id and user profile from studio environment.""" + if not self._is_in_studio(): + return + + try: + with open(NOTEBOOK_METADATA_FILE, "rb") as metadata_file: + metadata = json.loads(metadata_file.read()) + if not self._validate_domain_id( + metadata.get("DomainId") + ) or not self._validate_user_profile_name(metadata.get("UserProfileName")): + logger.warning( + "NOTEBOOK_METADATA_FILE detected but failed to get valid domain and user" + " from it." + ) + return + self._domain_id = metadata.get("DomainId") + self._user_profile_name = metadata.get("UserProfileName") + except OSError as err: + logger.warning("Could not load Studio metadata due to unexpected error. %s", err) + + def _get_presigned_url( + self, + create_presigned_url_kwargs: dict, + redirect: Optional[str] = None, + state: Optional[str] = None, + ): + """Generate a presigned URL to access a user's domain / user profile. + + Optional state and redirect parameters can be used to to have presigned URL automatically + redirect to a specific app and provide modifying data. + + Args: + create_presigned_url_kwargs (dict): Required. This dictionary should include the + parameters that will be used when calling create_presigned_domain_url via the boto3 + client. At a minimum, this should include the "DomainId" and "UserProfileName" + parameters as defined by create_presigned_domain_url's documentation. + Default: ``None`` + redirect (str): Optional. This value will be appended to the resulting presigned URL + in the format "&redirect=". This is used to automatically + redirect the user into a specific Studio app. + Default: ``None`` + state (str): Optional. This value will be appended to the resulting presigned URL + in the format "&state=". This is used to + automatically apply a state to the given app. Should be used in conjuction with + the redirect parameter. + Default: ``None`` + + Returns: + str: A presigned URL. + """ + response = self._sagemaker_client.create_presigned_domain_url(**create_presigned_url_kwargs) + if response["ResponseMetadata"]["HTTPStatusCode"] == 200: + url = response["AuthorizedUrl"] + else: + raise ValueError( + "An invalid status code was returned when creating a presigned URL." + f" See response for more: {response}" + ) + + if redirect: + url += f"&redirect={redirect}" + + if state: + url += f"&state={base64.b64encode(bytes(state, 'utf-8')).decode('utf-8')}" + + logger.warning( + "A presigned domain URL was generated. This is sensitive and should not be shared with" + " others." + ) + + return url + + def _is_in_studio(self): + """Check to see if NOTEBOOK_METADATA_FILE exists to verify Studio environment.""" + return os.path.isfile(NOTEBOOK_METADATA_FILE) + + def _open_url_in_web_browser(self, url: str): + """Open a URL in the default web browser. + + Args: + url (str): The URL to open. + """ + webbrowser.open(url) + + def _validate_domain_id(self, domain_id: Optional[str] = None): + """Validate domain id format. + + Args: + domain_id (str): Optional. The domain ID to validate. If one is not supplied, + self._domain_id will be used instead. + Default: ``None`` + + Returns: + bool: Whether the supplied domain ID is valid. + """ + if domain_id is None: + domain_id = self._domain_id + if domain_id is None or len(domain_id) > 63: + return False + return True + + def _validate_job_name(self, job_name: str): + """Validate training job name format. + + Args: + job_name (str): The job name to validate. + + Returns: + bool: Whether the supplied job name is valid. + """ + job_name_regex = "^[a-zA-Z0-9](-*[a-zA-Z0-9]){0,62}" + if not re.fullmatch(job_name_regex, job_name): + raise ValueError( + f"Invalid job name. Job name must match regular expression {job_name_regex}" + ) + + def _validate_user_profile_name(self, user_profile_name: Optional[str] = None): + """Validate user profile name format. + + Args: + user_profile_name (str): Optional. The user profile name to validate. If one is not + supplied, self._user_profile_name will be used instead. + Default: ``None`` + + Returns: + bool: Whether the supplied user profile name is valid. + """ + if user_profile_name is None: + user_profile_name = self._user_profile_name + user_profile_name_regex = "^[a-zA-Z0-9](-*[a-zA-Z0-9]){0,62}" + if user_profile_name is None or not re.fullmatch( + user_profile_name_regex, user_profile_name + ): + return False + return True + + def _validate_domain_and_user(self): + """Helper function to consolidate validation calls.""" + return self._validate_domain_id() and self._validate_user_profile_name() + + @abc.abstractmethod + def get_app_url(self): + """Abstract method to generate a URL to help access the application in Studio. + + Classes that inherit from BaseInteractiveApp should implement and override with what + parameters are needed for its specific use case. + """ diff --git a/src/sagemaker/interactive_apps/tensorboard.py b/src/sagemaker/interactive_apps/tensorboard.py index e2734de1ad..479883f04e 100644 --- a/src/sagemaker/interactive_apps/tensorboard.py +++ b/src/sagemaker/interactive_apps/tensorboard.py @@ -17,137 +17,136 @@ """ from __future__ import absolute_import -import json import logging -import os -import re from typing import Optional -from sagemaker.session import Session, NOTEBOOK_METADATA_FILE -logger = logging.getLogger(__name__) +from sagemaker.interactive_apps.base_interactive_app import BaseInteractiveApp +logger = logging.getLogger(__name__) -class TensorBoardApp(object): - """TensorBoardApp is a class for creating/accessing a TensorBoard app hosted on SageMaker.""" - def __init__(self, region: Optional[str] = None): - """Initialize a TensorBoardApp object. +class TensorBoardApp(BaseInteractiveApp): + """TensorBoardApp is a class for creating/accessing a TensorBoard app hosted on Studio.""" + + def get_app_url( + self, + training_job_name: Optional[str] = None, + open_in_default_web_browser: Optional[bool] = True, + create_presigned_domain_url: Optional[bool] = False, + domain_id: Optional[str] = None, + user_profile_name: Optional[str] = None, + optional_create_presigned_url_kwargs: Optional[dict] = None, + ): + """Generate a URL to help access the TensorBoard application hosted in Studio. + + For users that are already in SageMaker Studio, this method tries to get the + domain id and the user profile from the Studio environment. If successful, the generated + URL will direct to the TensorBoard application in SageMaker. Otherwise, it will direct + to the TensorBoard landing page in the SageMaker console. If a user outside of SageMaker + Studio passes in a valid domain ID and user profile name, the generated URL will be + presigned - authenticating the user and redirecting to the TensorBoard app once used. + Otherwise, the URL will direct to the TensorBoard landing page in the SageMaker console. + By default, the generated URL will attempt to open in the environment's default web + browser. Args: - region (str): The AWS Region, e.g. us-east-1. If not specified, - one is created using the default AWS configuration chain. - """ - if region: - self.region = region - else: - try: - self.region = Session().boto_region_name - except ValueError: - raise ValueError( - "Failed to get the Region information from the default config. Please either " - "pass your Region manually as an input argument or set up the local AWS " - "configuration." - ) + training_job_name (str): Optional. The name of the training job to pre-load in + TensorBoard. If nothing provided, the method just returns the TensorBoard + application URL. You can add training jobs later by using the SageMaker Data + Manager UI. + Default: ``None`` + open_in_default_web_browser (bool): Optional. When True, the URL will attempt to be + opened in the environment's default web browser. Otherwise, the resulting URL will + be returned by this function. + Default: ``True`` + create_presigned_domain_url (bool): Optional. Determines whether a presigned domain URL + should be generated instead of an unsigned URL. This only applies when called from + outside of a SageMaker Studio environment. If this is set to True inside of a + SageMaker Studio environment, it will be ignored. + Default: ``False`` + domain_id (str): Optional. This parameter should be passed when a user outside of + Studio wants a presigned URL to the TensorBoard application. This value will map to + 'DomainId' in the resulting create_presigned_domain_url call. Must be passed with + user_profile_name and create_presigned_domain_url set to True. + Default: ``None`` + user_profile_name (str): Optional. This parameter should be passed when a user outside + of Studio wants a presigned URL to the TensorBoard application. This value will + map to 'UserProfileName' in the resulting create_presigned_domain_url call. Must be + passed with domain_id and create_presigned_domain_url set to True. + Default: ``None`` + optional_create_presigned_url_kwargs (dict): Optional. This parameter + should be passed when a user outside of Studio wants a presigned URL to the + TensorBoard application and wants to modify the optional parameters of the + create_presigned_domain_url call. + Default: ``None`` - self._domain_id = None - self._user_profile_name = None - self._valid_domain_and_user = False - self._get_domain_and_user() + Returns: + str: A URL for TensorBoard hosted on SageMaker. + """ + if training_job_name is not None: + self._validate_job_name(training_job_name) - def __str__(self): - """Return str(self).""" - return f"TensorBoardApp(region={self.region})" + if optional_create_presigned_url_kwargs is None: + optional_create_presigned_url_kwargs = {} - def __repr__(self): - """Return repr(self).""" - return self.__str__() + if domain_id is not None: + optional_create_presigned_url_kwargs["DomainId"] = domain_id - def get_app_url(self, training_job_name: Optional[str] = None): - """Generates an unsigned URL to help access the TensorBoard application hosted in SageMaker. + if user_profile_name is not None: + optional_create_presigned_url_kwargs["UserProfileName"] = user_profile_name - For users that are already in SageMaker Studio, this method tries to get the domain id - and the user profile from the Studio environment. If succeeded, the generated URL will - direct to the TensorBoard application in SageMaker. Otherwise, it will direct to the - TensorBoard landing page in the SageMaker console. For non-Studio users, the URL will - direct to the TensorBoard landing page in the SageMaker console. - console. + if ( + create_presigned_domain_url + and not self._is_in_studio() + and self._validate_domain_id(optional_create_presigned_url_kwargs.get("DomainId")) + and self._validate_user_profile_name( + optional_create_presigned_url_kwargs.get("UserProfileName") + ) + ): + state_to_encode = None + redirect = "TensorBoard" - Args: - training_job_name (str): Optional. The name of the training job to pre-load in - TensorBoard. - If nothing provided, the method still returns the TensorBoard application URL, - but the application will not have any training jobs added for tracking. You can - add training jobs later by using the SageMaker Data Manager UI. - Default: ``None`` + if training_job_name is not None: + state_to_encode = ( + "/tensorboard/default/data/plugin/sagemaker_data_manager/" + + f"add_folder_or_job?Redirect=True&Name={training_job_name}" + ) - Returns: - str: An unsigned URL for TensorBoard hosted on SageMaker. - """ - if self._valid_domain_and_user: - url = "https://{}.studio.{}.sagemaker.aws/tensorboard/default".format( - self._domain_id, self.region + url = self._get_presigned_url( + optional_create_presigned_url_kwargs, redirect, state_to_encode + ) + elif self._is_in_studio() and self._validate_domain_and_user(): + if domain_id or user_profile_name: + logger.warning( + "Ignoring passed in domain_id and user_profile_name for Studio set values." + ) + url = ( + f"https://{self._domain_id}.studio.{self.region}." + + "sagemaker.aws/tensorboard/default" ) if training_job_name is not None: self._validate_job_name(training_job_name) url += ( - f"/data/plugin/sagemaker_data_manager/" - f"add_folder_or_job?Redirect=True&Name={training_job_name}" + "/data/plugin/sagemaker_data_manager/" + + f"add_folder_or_job?Redirect=True&Name={training_job_name}" ) else: url += "/#sagemaker_data_manager" else: - url = "https://{region}.console.aws.amazon.com/sagemaker/home?region={region}#/tensor-board-landing".format( - region=self.region - ) - if training_job_name is not None: - self._validate_job_name(training_job_name) - url += "/{}".format(training_job_name) - - return url - - def _get_domain_and_user(self): - """Get and validate studio domain id and user profile - - Get and validate studio domain id and user profile - from NOTEBOOK_METADATA_FILE in studio environment. - - Set _valid_domain_and_user to True if validation succeeded. - """ - if not os.path.isfile(NOTEBOOK_METADATA_FILE): - return - - with open(NOTEBOOK_METADATA_FILE, "rb") as f: - metadata = json.loads(f.read()) - self._domain_id = metadata.get("DomainId") - self._user_profile_name = metadata.get("UserProfileName") - if self._validate_domain_id() is True and self._validate_user_profile_name() is True: - self._valid_domain_and_user = True - else: + if domain_id or user_profile_name or create_presigned_domain_url: logger.warning( - "NOTEBOOK_METADATA_FILE detected but failed" - " to get valid domain and user from it." + "A valid domain ID and user profile name were not provided. " + "Providing default landing page URL as a result." ) - - def _validate_job_name(self, job_name: str): - """Validate training job name format.""" - job_name_regex = "^[a-zA-Z0-9](-*[a-zA-Z0-9]){0,62}" - if not re.fullmatch(job_name_regex, job_name): - raise ValueError( - "Invalid job name. Job name must match regular expression {}".format(job_name_regex) + url = ( + f"https://{self.region}.console.aws.amazon.com/sagemaker/home" + + f"?region={self.region}#/tensor-board-landing" ) + if training_job_name is not None: + url += f"/{training_job_name}" - def _validate_domain_id(self): - """Validate domain id format.""" - if self._domain_id is None or len(self._domain_id) > 63: - return False - return True - - def _validate_user_profile_name(self): - """Validate user profile name format.""" - user_profile_name_regex = "^[a-zA-Z0-9](-*[a-zA-Z0-9]){0,62}" - if self._user_profile_name is None or not re.fullmatch( - user_profile_name_regex, self._user_profile_name - ): - return False - return True + if open_in_default_web_browser: + self._open_url_in_web_browser(url) + url = "" + return url diff --git a/tests/unit/test_estimator.py b/tests/unit/test_estimator.py index 5e1e4d2645..70f03ee43d 100644 --- a/tests/unit/test_estimator.py +++ b/tests/unit/test_estimator.py @@ -46,6 +46,7 @@ from sagemaker.fw_utils import PROFILER_UNSUPPORTED_REGIONS from sagemaker.inputs import ShuffleConfig from sagemaker.instance_group import InstanceGroup +from sagemaker.interactive_apps import SupportedInteractiveAppTypes from sagemaker.model import FrameworkModel from sagemaker.mxnet.estimator import MXNet from sagemaker.predictor import Predictor @@ -5448,3 +5449,46 @@ def without_user_input(sess): ), ) assert actual == expected + + +def test_estimator_get_app_url_success(sagemaker_session): + job_name = "get-app-url-test-job-name" + f = DummyFramework( + entry_point=SCRIPT_PATH, + role=ROLE, + sagemaker_session=sagemaker_session, + base_job_name=job_name, + instance_groups=[ + InstanceGroup("group1", "ml.c4.xlarge", 1), + InstanceGroup("group2", "ml.m4.xlarge", 2), + ], + ) + f.fit("s3://mydata") + + url = f.get_app_url("TensorBoard", open_in_default_web_browser=False) + + assert url and job_name in url + + app_type = SupportedInteractiveAppTypes.TENSORBOARD + url = f.get_app_url(app_type, open_in_default_web_browser=False) + + assert url and job_name in url + + +def test_estimator_get_app_url_fail(sagemaker_session): + job_name = "get-app-url-test-job-name" + f = DummyFramework( + entry_point=SCRIPT_PATH, + role=ROLE, + sagemaker_session=sagemaker_session, + base_job_name=job_name, + instance_groups=[ + InstanceGroup("group1", "ml.c4.xlarge", 1), + InstanceGroup("group2", "ml.m4.xlarge", 2), + ], + ) + f.fit("s3://mydata") + with pytest.raises(ValueError) as error: + f.get_app_url("fake-app") + + assert "does not support URL retrieval." in str(error) diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index 7a31de9237..b07f90a55b 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -91,7 +91,10 @@ def test_default_session(boto3_default_session): @patch("boto3.DEFAULT_SESSION", None) @patch("boto3.Session") +@patch("boto3.DEFAULT_SESSION", None) def test_new_session_created(boto3_session): + # Need to have DEFAULT_SESSION return None as other unit tests can trigger creation of global + # default boto3 session that will persist and take precedence over boto3.Session() sess = Session() assert sess.boto_session is boto3_session.return_value diff --git a/tests/unit/test_tensorboard.py b/tests/unit/test_tensorboard.py index 44914b2d3f..0265377d08 100644 --- a/tests/unit/test_tensorboard.py +++ b/tests/unit/test_tensorboard.py @@ -10,18 +10,26 @@ # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. +"""Tests related to TensorBoardApp""" from __future__ import absolute_import -from sagemaker.interactive_apps.tensorboard import TensorBoardApp -from unittest.mock import patch, mock_open, PropertyMock - import json +from unittest.mock import patch, Mock, mock_open, PropertyMock + +import boto3 +import botocore import pytest +from sagemaker.interactive_apps.tensorboard import TensorBoardApp + + TEST_DOMAIN = "testdomain" TEST_USER_PROFILE = "testuser" TEST_REGION = "testregion" TEST_NOTEBOOK_METADATA = json.dumps({"DomainId": TEST_DOMAIN, "UserProfileName": TEST_USER_PROFILE}) +TEST_PRESIGNED_URL = ( + f"https://{TEST_DOMAIN}.studio.{TEST_REGION}.sagemaker.aws/auth?token=FAKETOKEN" +) TEST_TRAINING_JOB = "testjob" BASE_URL_STUDIO_FORMAT = "https://{}.studio.{}.sagemaker.aws/tensorboard/default" @@ -34,81 +42,96 @@ REDIRECT_NON_STUDIO_FORMAT = "/{}" +@patch("boto3.client") @patch("os.path.isfile") -def test_tb_init_and_url_non_studio_user(mock_file_exists): +def test_tb_init_and_url_non_studio_user(mock_in_studio, mock_client): """ Test TensorBoardApp for non Studio users. """ - mock_file_exists.return_value = False + mock_in_studio.return_value = False + mock_client.return_value = boto3.client("sagemaker") tb_app = TensorBoardApp(TEST_REGION) assert tb_app.region == TEST_REGION assert tb_app._domain_id is None assert tb_app._user_profile_name is None - assert tb_app._valid_domain_and_user is False + assert tb_app._validate_domain_and_user() is False # test url without job redirect - assert tb_app.get_app_url() == BASE_URL_NON_STUDIO_FORMAT.format(region=TEST_REGION) + assert tb_app.get_app_url( + open_in_default_web_browser=False + ) == BASE_URL_NON_STUDIO_FORMAT.format(region=TEST_REGION) # test url with valid job redirect - assert tb_app.get_app_url(TEST_TRAINING_JOB) == BASE_URL_NON_STUDIO_FORMAT.format( - region=TEST_REGION - ) + REDIRECT_NON_STUDIO_FORMAT.format(TEST_TRAINING_JOB) + assert tb_app.get_app_url( + TEST_TRAINING_JOB, open_in_default_web_browser=False + ) == BASE_URL_NON_STUDIO_FORMAT.format(region=TEST_REGION) + REDIRECT_NON_STUDIO_FORMAT.format( + TEST_TRAINING_JOB + ) # test url with invalid job redirect with pytest.raises(ValueError): tb_app.get_app_url("invald_job_name!") -@patch("os.path.isfile") -def test_tb_init_and_url_studio_user_valid_medatada(mock_file_exists): +@patch("boto3.client") +@patch("sagemaker.interactive_apps.base_interactive_app.BaseInteractiveApp._is_in_studio") +def test_tb_init_and_url_studio_user_valid_medatada(mock_in_studio, mock_client): """ Test TensorBoardApp for Studio user when the notebook metadata file provided by Studio is valid. """ - mock_file_exists.return_value = True + mock_in_studio.return_value = True + mock_client.return_value = boto3.client("sagemaker") with patch("builtins.open", mock_open(read_data=TEST_NOTEBOOK_METADATA)): tb_app = TensorBoardApp(TEST_REGION) assert tb_app.region == TEST_REGION assert tb_app._domain_id == TEST_DOMAIN assert tb_app._user_profile_name == TEST_USER_PROFILE - assert tb_app._valid_domain_and_user is True + assert tb_app._validate_domain_and_user() is True # test url without job redirect assert ( - tb_app.get_app_url() + tb_app.get_app_url(open_in_default_web_browser=False) == BASE_URL_STUDIO_FORMAT.format(TEST_DOMAIN, TEST_REGION) + "/#sagemaker_data_manager" ) # test url with valid job redirect - assert tb_app.get_app_url(TEST_TRAINING_JOB) == BASE_URL_STUDIO_FORMAT.format( + assert tb_app.get_app_url( + TEST_TRAINING_JOB, open_in_default_web_browser=False + ) == BASE_URL_STUDIO_FORMAT.format( TEST_DOMAIN, TEST_REGION - ) + REDIRECT_STUDIO_FORMAT.format(TEST_TRAINING_JOB) + ) + REDIRECT_STUDIO_FORMAT.format( + TEST_TRAINING_JOB + ) # test url with invalid job redirect with pytest.raises(ValueError): tb_app.get_app_url("invald_job_name!") -@patch("os.path.isfile") -def test_tb_init_and_url_studio_user_invalid_medatada(mock_file_exists): +@patch("boto3.client") +@patch("sagemaker.interactive_apps.base_interactive_app.BaseInteractiveApp._is_in_studio") +def test_tb_init_and_url_studio_user_invalid_medatada(mock_in_studio, mock_client): """ - Test TensorBoardApp for Studio user when the notebook metadata file provided by Studio is invalid. + Test TensorBoardApp for Amazon SageMaker Studio user when the notebook metadata file provided + by Studio is invalid. """ - mock_file_exists.return_value = True + mock_in_studio.return_value = True + mock_client.return_value = boto3.client("sagemaker") # test file does not contain domain and user profle with patch("builtins.open", mock_open(read_data=json.dumps({"Fake": "Fake"}))): - assert TensorBoardApp(TEST_REGION).get_app_url() == BASE_URL_NON_STUDIO_FORMAT.format( - region=TEST_REGION - ) + assert TensorBoardApp(TEST_REGION).get_app_url( + open_in_default_web_browser=False + ) == BASE_URL_NON_STUDIO_FORMAT.format(region=TEST_REGION) # test invalid user profile name with patch( "builtins.open", mock_open(read_data=json.dumps({"DomainId": TEST_DOMAIN, "UserProfileName": "u" * 64})), ): - assert TensorBoardApp(TEST_REGION).get_app_url() == BASE_URL_NON_STUDIO_FORMAT.format( - region=TEST_REGION - ) + assert TensorBoardApp(TEST_REGION).get_app_url( + open_in_default_web_browser=False + ) == BASE_URL_NON_STUDIO_FORMAT.format(region=TEST_REGION) # test invalid domain id with patch( @@ -117,9 +140,134 @@ def test_tb_init_and_url_studio_user_invalid_medatada(mock_file_exists): read_data=json.dumps({"DomainId": "d" * 64, "UserProfileName": TEST_USER_PROFILE}) ), ): - assert TensorBoardApp(TEST_REGION).get_app_url() == BASE_URL_NON_STUDIO_FORMAT.format( - region=TEST_REGION + assert TensorBoardApp(TEST_REGION).get_app_url( + open_in_default_web_browser=False + ) == BASE_URL_NON_STUDIO_FORMAT.format(region=TEST_REGION) + + +@patch("boto3.client") +@patch("sagemaker.interactive_apps.base_interactive_app.BaseInteractiveApp._is_in_studio") +def test_tb_presigned_url_success(mock_in_studio, mock_client): + mock_in_studio.return_value = False + resp = { + "ResponseMetadata": {"HTTPStatusCode": 200}, + "AuthorizedUrl": TEST_PRESIGNED_URL, + } + attrs = {"create_presigned_domain_url.return_value": resp} + mock_client.return_value = Mock(**attrs) + + url = TensorBoardApp(TEST_REGION).get_app_url( + domain_id=TEST_DOMAIN, + user_profile_name=TEST_USER_PROFILE, + create_presigned_domain_url=True, + open_in_default_web_browser=False, + ) + assert url == f"{TEST_PRESIGNED_URL}&redirect=TensorBoard" + + url = TensorBoardApp(TEST_REGION).get_app_url( + training_job_name=TEST_TRAINING_JOB, + domain_id=TEST_DOMAIN, + user_profile_name=TEST_USER_PROFILE, + create_presigned_domain_url=True, + open_in_default_web_browser=False, + ) + assert url.startswith(f"{TEST_PRESIGNED_URL}&redirect=TensorBoard&state=") + assert url.endswith("==") + + +@patch("boto3.client") +@patch("sagemaker.interactive_apps.base_interactive_app.BaseInteractiveApp._is_in_studio") +def test_tb_presigned_url_success_open_in_web_browser(mock_in_studio, mock_client): + mock_in_studio.return_value = False + resp = { + "ResponseMetadata": {"HTTPStatusCode": 200}, + "AuthorizedUrl": TEST_PRESIGNED_URL, + } + attrs = {"create_presigned_domain_url.return_value": resp} + mock_client.return_value = Mock(**attrs) + + with patch("webbrowser.open") as mock_web_browser_open: + url = TensorBoardApp(TEST_REGION).get_app_url( + domain_id=TEST_DOMAIN, + user_profile_name=TEST_USER_PROFILE, + create_presigned_domain_url=True, + open_in_default_web_browser=True, ) + mock_web_browser_open.assert_called_with(f"{TEST_PRESIGNED_URL}&redirect=TensorBoard") + assert url == "" + + +@patch("boto3.client") +@patch("sagemaker.interactive_apps.base_interactive_app.BaseInteractiveApp._is_in_studio") +def test_tb_presigned_url_not_returned_without_presigned_flag(mock_in_studio, mock_client): + mock_in_studio.return_value = False + mock_client.return_value = boto3.client("sagemaker") + + url = TensorBoardApp(TEST_REGION).get_app_url( + domain_id=TEST_DOMAIN, + user_profile_name=TEST_USER_PROFILE, + create_presigned_domain_url=False, + open_in_default_web_browser=False, + ) + assert url == BASE_URL_NON_STUDIO_FORMAT.format(region=TEST_REGION) + + +@patch("boto3.client") +@patch("sagemaker.interactive_apps.base_interactive_app.BaseInteractiveApp._is_in_studio") +def test_tb_presigned_url_failure(mock_in_studio, mock_client): + mock_in_studio.return_value = False + resp = {"ResponseMetadata": {"HTTPStatusCode": 400}} + attrs = {"create_presigned_domain_url.return_value": resp} + mock_client.return_value = Mock(**attrs) + + with pytest.raises(ValueError): + TensorBoardApp(TEST_REGION).get_app_url( + domain_id=TEST_DOMAIN, + user_profile_name=TEST_USER_PROFILE, + create_presigned_domain_url=True, + open_in_default_web_browser=False, + ) + + +@patch("sagemaker.interactive_apps.base_interactive_app.BaseInteractiveApp._is_in_studio") +def test_tb_invalid_presigned_kwargs(mock_in_studio): + mock_in_studio.return_value = False + invalid_kwargs = { + "fake-parameter": True, + "DomainId": TEST_DOMAIN, + "UserProfileName": TEST_USER_PROFILE, + } + + with pytest.raises(botocore.exceptions.ParamValidationError): + TensorBoardApp(TEST_REGION).get_app_url( + optional_create_presigned_url_kwargs=invalid_kwargs, + create_presigned_domain_url=True, + ) + + +@patch("boto3.client") +@patch("sagemaker.interactive_apps.base_interactive_app.BaseInteractiveApp._is_in_studio") +def test_tb_valid_presigned_kwargs(mock_in_studio, mock_client): + mock_in_studio.return_value = False + + rsp = { + "ResponseMetadata": {"HTTPStatusCode": 200}, + "AuthorizedUrl": TEST_PRESIGNED_URL, + } + mock_client = boto3.client("sagemaker") + mock_client.create_presigned_domain_url = Mock(name="create_presigned_domain_url") + mock_client.create_presigned_domain_url.return_value = rsp + + valid_kwargs = {"DomainId": TEST_DOMAIN, "UserProfileName": TEST_USER_PROFILE} + + url = TensorBoardApp(TEST_REGION).get_app_url( + optional_create_presigned_url_kwargs=valid_kwargs, + create_presigned_domain_url=True, + open_in_default_web_browser=False, + ) + + assert url == f"{TEST_PRESIGNED_URL}&redirect=TensorBoard" + mock_client.create_presigned_domain_url.assert_called_once_with(**valid_kwargs) def test_tb_init_with_default_region():