Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion src/sagemaker/remote_function/core/serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,12 @@ def deserialize(s3_uri: str, bytes_to_deserialize: bytes) -> Any:
return cloudpickle.loads(bytes_to_deserialize)
except Exception as e:
raise DeserializationError(
"Error when deserializing bytes downloaded from {}: {}".format(s3_uri, repr(e))
"Error when deserializing bytes downloaded from {}: {}. "
"NOTE: this may be caused by inconsistent sagemaker python sdk versions "
"where remote function runs versus the one used on client side. "
"If the sagemaker versions do not match, a warning message would "
"be logged starting with 'Inconsistent sagemaker versions found'. "
"Please check it to validate.".format(s3_uri, repr(e))
) from e


Expand Down
6 changes: 6 additions & 0 deletions src/sagemaker/remote_function/job.py
Original file line number Diff line number Diff line change
Expand Up @@ -786,6 +786,12 @@ def compile(
container_args.extend(
["--client_python_version", RuntimeEnvironmentManager()._current_python_version()]
)
container_args.extend(
[
"--client_sagemaker_pysdk_version",
RuntimeEnvironmentManager()._current_sagemaker_pysdk_version(),
]
)
container_args.extend(
[
"--dependency_settings",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ def main(sys_args=None):
try:
args = _parse_args(sys_args)
client_python_version = args.client_python_version
client_sagemaker_pysdk_version = args.client_sagemaker_pysdk_version
job_conda_env = args.job_conda_env
pipeline_execution_id = args.pipeline_execution_id
dependency_settings = _DependencySettings.from_string(args.dependency_settings)
Expand All @@ -64,6 +65,9 @@ def main(sys_args=None):
conda_env = job_conda_env or os.getenv("SAGEMAKER_JOB_CONDA_ENV")

RuntimeEnvironmentManager()._validate_python_version(client_python_version, conda_env)
RuntimeEnvironmentManager()._validate_sagemaker_pysdk_version(
client_sagemaker_pysdk_version
)

user = getpass.getuser()
if user != "root":
Expand Down Expand Up @@ -274,6 +278,7 @@ def _parse_args(sys_args):
parser = argparse.ArgumentParser()
parser.add_argument("--job_conda_env", type=str)
parser.add_argument("--client_python_version", type=str)
parser.add_argument("--client_sagemaker_pysdk_version", type=str, default=None)
parser.add_argument("--pipeline_execution_id", type=str)
parser.add_argument("--dependency_settings", type=str)
parser.add_argument("--func_step_s3_dir", type=str)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
import dataclasses
import json

import sagemaker


class _UTCFormatter(logging.Formatter):
"""Class that overrides the default local time provider in log formatter."""
Expand Down Expand Up @@ -326,6 +328,11 @@ def _current_python_version(self):

return f"{sys.version_info.major}.{sys.version_info.minor}".strip()

def _current_sagemaker_pysdk_version(self):
    """Return the sagemaker python sdk version of the running program.

    The value is read from ``sagemaker.__version__``, i.e. the version of
    the sagemaker package imported by the current process.
    """
    sdk_version = sagemaker.__version__
    return sdk_version

def _validate_python_version(self, client_python_version: str, conda_env: str = None):
"""Validate the python version

Expand All @@ -344,6 +351,29 @@ def _validate_python_version(self, client_python_version: str, conda_env: str =
f"is same as the local python version."
)

def _validate_sagemaker_pysdk_version(self, client_sagemaker_pysdk_version):
    """Warn when the job and client sagemaker python sdk versions differ.

    Compares the sagemaker python sdk version of the environment where the
    remote function runs against the version reported by the client side.
    A mismatch is not treated as an error: only a warning is logged, to call
    out that unexpected behaviors may occur.

    Args:
        client_sagemaker_pysdk_version: sagemaker python sdk version used on
            the client side. A falsy value (e.g. ``None``, when the client
            did not pass its version) skips the check.
    """
    # Nothing to compare against when the client did not report a version.
    if not client_sagemaker_pysdk_version:
        return

    job_sagemaker_pysdk_version = self._current_sagemaker_pysdk_version()
    if client_sagemaker_pysdk_version == job_sagemaker_pysdk_version:
        return

    # The leading "Inconsistent sagemaker versions found" text is kept
    # verbatim so that it can be searched for in the job logs.
    logger.warning(
        "Inconsistent sagemaker versions found: "
        "sagemaker pysdk version found in the container is "
        "'%s' which does not match the '%s' on the local client. "
        "Please make sure that the sagemaker pysdk version used in the "
        "training container is the same as the one on the local client "
        "to avoid unexpected behaviors.",
        job_sagemaker_pysdk_version,
        client_sagemaker_pysdk_version,
    )


def _run_and_get_output_shell_cmd(cmd: str) -> str:
"""Run and return the output of the given shell command"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,8 @@ def square(x):
with pytest.raises(
DeserializationError,
match=rf"Error when deserializing bytes downloaded from {s3_uri}/payload.pkl: "
+ r"RuntimeError\('some failure when loads'\)",
+ r"RuntimeError\('some failure when loads'\). "
+ r"NOTE: this may be caused by inconsistent sagemaker python sdk versions",
):
deserialize_func_from_s3(sagemaker_session=Mock(), s3_uri=s3_uri, hmac_key=HMAC_KEY)

Expand Down Expand Up @@ -397,7 +398,8 @@ def __init__(self, x):
with pytest.raises(
DeserializationError,
match=rf"Error when deserializing bytes downloaded from {s3_uri}/payload.pkl: "
+ r"RuntimeError\('some failure when loads'\)",
+ r"RuntimeError\('some failure when loads'\). "
+ r"NOTE: this may be caused by inconsistent sagemaker python sdk versions",
):
deserialize_obj_from_s3(sagemaker_session=Mock(), s3_uri=s3_uri, hmac_key=HMAC_KEY)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
CURR_WORKING_DIR = "/user/set/workdir"
TEST_DEPENDENCIES_PATH = "/user/set/workdir/sagemaker_remote_function_workspace"
TEST_PYTHON_VERSION = "3.10"
TEST_SAGEMAKER_PYSDK_VERSION = "2.205.0"
TEST_WORKSPACE_ARCHIVE_DIR_PATH = "/opt/ml/input/data/sm_rf_user_ws"
TEST_WORKSPACE_ARCHIVE_PATH = "/opt/ml/input/data/sm_rf_user_ws/workspace.zip"
TEST_EXECUTION_ID = "test_execution_id"
Expand All @@ -44,6 +45,8 @@ def args_for_remote():
TEST_JOB_CONDA_ENV,
"--client_python_version",
TEST_PYTHON_VERSION,
"--client_sagemaker_pysdk_version",
TEST_SAGEMAKER_PYSDK_VERSION,
"--dependency_settings",
_DependencySettings(TEST_DEPENDENCY_FILE_NAME).to_string(),
]
Expand All @@ -55,6 +58,8 @@ def args_for_step():
TEST_JOB_CONDA_ENV,
"--client_python_version",
TEST_PYTHON_VERSION,
"--client_sagemaker_pysdk_version",
TEST_SAGEMAKER_PYSDK_VERSION,
"--pipeline_execution_id",
TEST_EXECUTION_ID,
"--func_step_s3_dir",
Expand All @@ -63,6 +68,10 @@ def args_for_step():


@patch("sys.exit")
@patch(
"sagemaker.remote_function.runtime_environment.runtime_environment_manager."
"RuntimeEnvironmentManager._validate_sagemaker_pysdk_version"
)
@patch(
"sagemaker.remote_function.runtime_environment.runtime_environment_manager."
"RuntimeEnvironmentManager._validate_python_version"
Expand Down Expand Up @@ -90,12 +99,75 @@ def test_main_success_remote_job_with_root_user(
run_pre_exec_script,
bootstrap_runtime,
validate_python,
validate_sagemaker,
_exit_process,
):
bootstrap.main(args_for_remote())

change_dir_permission.assert_not_called()
validate_python.assert_called_once_with(TEST_PYTHON_VERSION, TEST_JOB_CONDA_ENV)
validate_sagemaker.assert_called_once_with(TEST_SAGEMAKER_PYSDK_VERSION)
bootstrap_remote.assert_called_once_with(
TEST_PYTHON_VERSION,
TEST_JOB_CONDA_ENV,
_DependencySettings(TEST_DEPENDENCY_FILE_NAME),
)
run_pre_exec_script.assert_not_called()
bootstrap_runtime.assert_not_called()
_exit_process.assert_called_with(0)


@patch("sys.exit")
@patch(
"sagemaker.remote_function.runtime_environment.runtime_environment_manager."
"RuntimeEnvironmentManager._validate_sagemaker_pysdk_version"
)
@patch(
"sagemaker.remote_function.runtime_environment.runtime_environment_manager."
"RuntimeEnvironmentManager._validate_python_version"
)
@patch(
"sagemaker.remote_function.runtime_environment.runtime_environment_manager."
"RuntimeEnvironmentManager.bootstrap"
)
@patch(
"sagemaker.remote_function.runtime_environment.runtime_environment_manager."
"RuntimeEnvironmentManager.run_pre_exec_script"
)
@patch(
"sagemaker.remote_function.runtime_environment.bootstrap_runtime_environment."
"_bootstrap_runtime_env_for_remote_function"
)
@patch("getpass.getuser", MagicMock(return_value="root"))
@patch(
"sagemaker.remote_function.runtime_environment.runtime_environment_manager."
"RuntimeEnvironmentManager.change_dir_permission"
)
def test_main_success_with_obsoleted_args_that_missing_sagemaker_version(
change_dir_permission,
bootstrap_remote,
run_pre_exec_script,
bootstrap_runtime,
validate_python,
validate_sagemaker,
_exit_process,
):
# Backward-compatibility test:
# older SDK versions do not pass the client-side sagemaker_pysdk_version to the job,
# so the parsed value is None, which should not lead to the version-mismatch warning.
obsoleted_args = [
"--job_conda_env",
TEST_JOB_CONDA_ENV,
"--client_python_version",
TEST_PYTHON_VERSION,
"--dependency_settings",
_DependencySettings(TEST_DEPENDENCY_FILE_NAME).to_string(),
]
bootstrap.main(obsoleted_args)

change_dir_permission.assert_not_called()
validate_python.assert_called_once_with(TEST_PYTHON_VERSION, TEST_JOB_CONDA_ENV)
validate_sagemaker.assert_called_once_with(None)
bootstrap_remote.assert_called_once_with(
TEST_PYTHON_VERSION,
TEST_JOB_CONDA_ENV,
Expand All @@ -107,6 +179,10 @@ def test_main_success_remote_job_with_root_user(


@patch("sys.exit")
@patch(
"sagemaker.remote_function.runtime_environment.runtime_environment_manager."
"RuntimeEnvironmentManager._validate_sagemaker_pysdk_version"
)
@patch(
"sagemaker.remote_function.runtime_environment.runtime_environment_manager."
"RuntimeEnvironmentManager._validate_python_version"
Expand Down Expand Up @@ -134,11 +210,13 @@ def test_main_success_pipeline_step_with_root_user(
run_pre_exec_script,
bootstrap_runtime,
validate_python,
validate_sagemaker,
_exit_process,
):
bootstrap.main(args_for_step())
change_dir_permission.assert_not_called()
validate_python.assert_called_once_with(TEST_PYTHON_VERSION, TEST_JOB_CONDA_ENV)
validate_sagemaker.assert_called_once_with(TEST_SAGEMAKER_PYSDK_VERSION)
bootstrap_step.assert_called_once_with(
TEST_PYTHON_VERSION,
FUNC_STEP_WORKSPACE,
Expand All @@ -150,6 +228,10 @@ def test_main_success_pipeline_step_with_root_user(
_exit_process.assert_called_with(0)


@patch(
"sagemaker.remote_function.runtime_environment.runtime_environment_manager."
"RuntimeEnvironmentManager._validate_sagemaker_pysdk_version"
)
@patch(
"sagemaker.remote_function.runtime_environment.runtime_environment_manager."
"RuntimeEnvironmentManager._validate_python_version"
Expand Down Expand Up @@ -178,6 +260,7 @@ def test_main_failure_remote_job_with_root_user(
write_failure,
_exit_process,
validate_python,
validate_sagemaker,
):
runtime_err = RuntimeEnvironmentError("some failure reason")
bootstrap_runtime.side_effect = runtime_err
Expand All @@ -186,12 +269,17 @@ def test_main_failure_remote_job_with_root_user(

change_dir_permission.assert_not_called()
validate_python.assert_called_once_with(TEST_PYTHON_VERSION, TEST_JOB_CONDA_ENV)
validate_sagemaker.assert_called_once_with(TEST_SAGEMAKER_PYSDK_VERSION)
run_pre_exec_script.assert_not_called()
bootstrap_runtime.assert_called()
write_failure.assert_called_with(str(runtime_err))
_exit_process.assert_called_with(1)


@patch(
"sagemaker.remote_function.runtime_environment.runtime_environment_manager."
"RuntimeEnvironmentManager._validate_sagemaker_pysdk_version"
)
@patch(
"sagemaker.remote_function.runtime_environment.runtime_environment_manager."
"RuntimeEnvironmentManager._validate_python_version"
Expand Down Expand Up @@ -220,6 +308,7 @@ def test_main_failure_pipeline_step_with_root_user(
write_failure,
_exit_process,
validate_python,
validate_sagemaker,
):
runtime_err = RuntimeEnvironmentError("some failure reason")
bootstrap_runtime.side_effect = runtime_err
Expand All @@ -228,13 +317,18 @@ def test_main_failure_pipeline_step_with_root_user(

change_dir_permission.assert_not_called()
validate_python.assert_called_once_with(TEST_PYTHON_VERSION, TEST_JOB_CONDA_ENV)
validate_sagemaker.assert_called_once_with(TEST_SAGEMAKER_PYSDK_VERSION)
run_pre_exec_script.assert_not_called()
bootstrap_runtime.assert_called()
write_failure.assert_called_with(str(runtime_err))
_exit_process.assert_called_with(1)


@patch("sys.exit")
@patch(
"sagemaker.remote_function.runtime_environment.runtime_environment_manager."
"RuntimeEnvironmentManager._validate_sagemaker_pysdk_version"
)
@patch(
"sagemaker.remote_function.runtime_environment.runtime_environment_manager."
"RuntimeEnvironmentManager._validate_python_version"
Expand Down Expand Up @@ -262,6 +356,7 @@ def test_main_remote_job_with_non_root_user(
run_pre_exec_script,
bootstrap_runtime,
validate_python,
validate_sagemaker,
_exit_process,
):
bootstrap.main(args_for_remote())
Expand All @@ -270,6 +365,7 @@ def test_main_remote_job_with_non_root_user(
dirs=bootstrap.JOB_OUTPUT_DIRS, new_permission="777"
)
validate_python.assert_called_once_with(TEST_PYTHON_VERSION, TEST_JOB_CONDA_ENV)
validate_sagemaker.assert_called_once_with(TEST_SAGEMAKER_PYSDK_VERSION)
bootstrap_remote.assert_called_once_with(
TEST_PYTHON_VERSION,
TEST_JOB_CONDA_ENV,
Expand All @@ -281,6 +377,10 @@ def test_main_remote_job_with_non_root_user(


@patch("sys.exit")
@patch(
"sagemaker.remote_function.runtime_environment.runtime_environment_manager."
"RuntimeEnvironmentManager._validate_sagemaker_pysdk_version"
)
@patch(
"sagemaker.remote_function.runtime_environment.runtime_environment_manager."
"RuntimeEnvironmentManager._validate_python_version"
Expand Down Expand Up @@ -308,6 +408,7 @@ def test_main_pipeline_step_with_non_root_user(
run_pre_exec_script,
bootstrap_runtime,
validate_python,
validate_sagemaker,
_exit_process,
):
bootstrap.main(args_for_step())
Expand All @@ -316,6 +417,7 @@ def test_main_pipeline_step_with_non_root_user(
dirs=bootstrap.JOB_OUTPUT_DIRS, new_permission="777"
)
validate_python.assert_called_once_with(TEST_PYTHON_VERSION, TEST_JOB_CONDA_ENV)
validate_sagemaker.assert_called_once_with(TEST_SAGEMAKER_PYSDK_VERSION)
bootstrap_step.assert_called_once_with(
TEST_PYTHON_VERSION,
FUNC_STEP_WORKSPACE,
Expand Down
Loading