diff --git a/src/sagemaker/remote_function/core/serialization.py b/src/sagemaker/remote_function/core/serialization.py
index c794d0aac5..1bbc4bf734 100644
--- a/src/sagemaker/remote_function/core/serialization.py
+++ b/src/sagemaker/remote_function/core/serialization.py
@@ -15,7 +15,6 @@
 import dataclasses
 import json
-import os
 import sys
 import hmac
 import hashlib
 
@@ -29,6 +28,8 @@
 from tblib import pickling_support
 
 
+# NOTE: do not use os.path.join for S3 URIs; it emits backslash separators on Windows
+
 def _get_python_version():
     return f"{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}"
 
@@ -143,18 +144,16 @@ def serialize_func_to_s3(
     Raises:
         SerializationError: when fail to serialize function to bytes.
     """
 
     bytes_to_upload = CloudpickleSerializer.serialize(func)
 
-    _upload_bytes_to_s3(
-        bytes_to_upload, os.path.join(s3_uri, "payload.pkl"), s3_kms_key, sagemaker_session
-    )
+    _upload_bytes_to_s3(bytes_to_upload, f"{s3_uri}/payload.pkl", s3_kms_key, sagemaker_session)
 
     sha256_hash = _compute_hash(bytes_to_upload, secret_key=hmac_key)
 
     _upload_bytes_to_s3(
         _MetaData(sha256_hash).to_json(),
-        os.path.join(s3_uri, "metadata.json"),
+        f"{s3_uri}/metadata.json",
         s3_kms_key,
         sagemaker_session,
     )
@@ -177,20 +176,16 @@ def deserialize_func_from_s3(sagemaker_session: Session, s3_uri: str, hmac_key:
         DeserializationError: when fail to serialize function to bytes.
     """
 
     metadata = _MetaData.from_json(
-        _read_bytes_from_s3(os.path.join(s3_uri, "metadata.json"), sagemaker_session)
+        _read_bytes_from_s3(f"{s3_uri}/metadata.json", sagemaker_session)
     )
 
-    bytes_to_deserialize = _read_bytes_from_s3(
-        os.path.join(s3_uri, "payload.pkl"), sagemaker_session
-    )
+    bytes_to_deserialize = _read_bytes_from_s3(f"{s3_uri}/payload.pkl", sagemaker_session)
 
     _perform_integrity_check(
         expected_hash_value=metadata.sha256_hash, secret_key=hmac_key, buffer=bytes_to_deserialize
     )
 
-    return CloudpickleSerializer.deserialize(
-        os.path.join(s3_uri, "payload.pkl"), bytes_to_deserialize
-    )
+    return CloudpickleSerializer.deserialize(f"{s3_uri}/payload.pkl", bytes_to_deserialize)
 
 def serialize_obj_to_s3(
@@ -211,15 +206,13 @@ def serialize_obj_to_s3(
 
     bytes_to_upload = CloudpickleSerializer.serialize(obj)
 
-    _upload_bytes_to_s3(
-        bytes_to_upload, os.path.join(s3_uri, "payload.pkl"), s3_kms_key, sagemaker_session
-    )
+    _upload_bytes_to_s3(bytes_to_upload, f"{s3_uri}/payload.pkl", s3_kms_key, sagemaker_session)
 
     sha256_hash = _compute_hash(bytes_to_upload, secret_key=hmac_key)
 
     _upload_bytes_to_s3(
         _MetaData(sha256_hash).to_json(),
-        os.path.join(s3_uri, "metadata.json"),
+        f"{s3_uri}/metadata.json",
         s3_kms_key,
         sagemaker_session,
     )
@@ -240,20 +233,16 @@ def deserialize_obj_from_s3(sagemaker_session: Session, s3_uri: str, hmac_key: s
     """
 
     metadata = _MetaData.from_json(
-        _read_bytes_from_s3(os.path.join(s3_uri, "metadata.json"), sagemaker_session)
+        _read_bytes_from_s3(f"{s3_uri}/metadata.json", sagemaker_session)
     )
 
-    bytes_to_deserialize = _read_bytes_from_s3(
-        os.path.join(s3_uri, "payload.pkl"), sagemaker_session
-    )
+    bytes_to_deserialize = _read_bytes_from_s3(f"{s3_uri}/payload.pkl", sagemaker_session)
 
     _perform_integrity_check(
         expected_hash_value=metadata.sha256_hash, secret_key=hmac_key, buffer=bytes_to_deserialize
     )
 
-    return CloudpickleSerializer.deserialize(
-        os.path.join(s3_uri, "payload.pkl"), bytes_to_deserialize
-    )
+    return CloudpickleSerializer.deserialize(f"{s3_uri}/payload.pkl", bytes_to_deserialize)
 
 
 def serialize_exception_to_s3(
@@ -275,15 +264,13 @@
 
     bytes_to_upload = CloudpickleSerializer.serialize(exc)
 
-    _upload_bytes_to_s3(
-        bytes_to_upload, os.path.join(s3_uri, "payload.pkl"), s3_kms_key, sagemaker_session
-    )
+    _upload_bytes_to_s3(bytes_to_upload, f"{s3_uri}/payload.pkl", s3_kms_key, sagemaker_session)
 
     sha256_hash = _compute_hash(bytes_to_upload, secret_key=hmac_key)
 
     _upload_bytes_to_s3(
         _MetaData(sha256_hash).to_json(),
-        os.path.join(s3_uri, "metadata.json"),
+        f"{s3_uri}/metadata.json",
         s3_kms_key,
         sagemaker_session,
     )
@@ -304,20 +291,16 @@ def deserialize_exception_from_s3(sagemaker_session: Session, s3_uri: str, hmac_
     """
 
     metadata = _MetaData.from_json(
-        _read_bytes_from_s3(os.path.join(s3_uri, "metadata.json"), sagemaker_session)
+        _read_bytes_from_s3(f"{s3_uri}/metadata.json", sagemaker_session)
     )
 
-    bytes_to_deserialize = _read_bytes_from_s3(
-        os.path.join(s3_uri, "payload.pkl"), sagemaker_session
-    )
+    bytes_to_deserialize = _read_bytes_from_s3(f"{s3_uri}/payload.pkl", sagemaker_session)
 
     _perform_integrity_check(
         expected_hash_value=metadata.sha256_hash, secret_key=hmac_key, buffer=bytes_to_deserialize
     )
 
-    return CloudpickleSerializer.deserialize(
-        os.path.join(s3_uri, "payload.pkl"), bytes_to_deserialize
-    )
+    return CloudpickleSerializer.deserialize(f"{s3_uri}/payload.pkl", bytes_to_deserialize)
 
 
 def _upload_bytes_to_s3(bytes, s3_uri, s3_kms_key, sagemaker_session):
diff --git a/src/sagemaker/remote_function/job.py b/src/sagemaker/remote_function/job.py
index f2fe3f7c73..128dd33139 100644
--- a/src/sagemaker/remote_function/job.py
+++ b/src/sagemaker/remote_function/job.py
@@ -860,7 +860,7 @@ def _prepare_and_upload_runtime_scripts(
     )
     shutil.copy2(spark_script_path, bootstrap_scripts)
 
-    with open(entrypoint_script_path, "w") as file:
+    with open(entrypoint_script_path, "w", newline="\n") as file:
         file.writelines(entry_point_script)
 
     bootstrap_script_path = os.path.join(
diff --git a/src/sagemaker/remote_function/runtime_environment/bootstrap_runtime_environment.py b/src/sagemaker/remote_function/runtime_environment/bootstrap_runtime_environment.py
index 6655e1febf..aa002cceab 100644
--- a/src/sagemaker/remote_function/runtime_environment/bootstrap_runtime_environment.py
+++ b/src/sagemaker/remote_function/runtime_environment/bootstrap_runtime_environment.py
@@ -74,7 +74,7 @@ def _bootstrap_runtime_environment(
     Args:
        conda_env (str): conda environment to be activated. Default is None.
     """
-    workspace_archive_dir_path = os.path.join(BASE_CHANNEL_PATH, REMOTE_FUNCTION_WORKSPACE)
+    workspace_archive_dir_path = f"{BASE_CHANNEL_PATH}/{REMOTE_FUNCTION_WORKSPACE}"
 
     if not os.path.exists(workspace_archive_dir_path):
         logger.info(
@@ -84,7 +84,7 @@ def _bootstrap_runtime_environment(
         return
 
     # Unpack user workspace archive first.
-    workspace_archive_path = os.path.join(workspace_archive_dir_path, "workspace.zip")
+    workspace_archive_path = f"{workspace_archive_dir_path}/workspace.zip"
     if not os.path.isfile(workspace_archive_path):
         logger.info(
             "Workspace archive '%s' does not exist. Assuming no dependencies to bootstrap.",
diff --git a/tests/unit/sagemaker/remote_function/core/test_serialization.py b/tests/unit/sagemaker/remote_function/core/test_serialization.py
index 28f5b215e8..7774a9a8e1 100644
--- a/tests/unit/sagemaker/remote_function/core/test_serialization.py
+++ b/tests/unit/sagemaker/remote_function/core/test_serialization.py
@@ -12,7 +12,6 @@
 # language governing permissions and limitations under the License.
 from __future__ import absolute_import
 
-import os.path
 import random
 import string
 import pytest
@@ -186,7 +185,7 @@ def square(x):
     serialize_func_to_s3(
         func=square, sagemaker_session=Mock(), s3_uri=s3_uri, s3_kms_key=KMS_KEY, hmac_key=HMAC_KEY
     )
-    mock_s3[os.path.join(s3_uri, "metadata.json")] = b"not json serializable"
+    mock_s3[f"{s3_uri}/metadata.json"] = b"not json serializable"
 
     del square
 
diff --git a/tests/unit/sagemaker/remote_function/runtime_environment/test_bootstrap_runtime_environment.py b/tests/unit/sagemaker/remote_function/runtime_environment/test_bootstrap_runtime_environment.py
index cf88775e49..712540e39a 100644
--- a/tests/unit/sagemaker/remote_function/runtime_environment/test_bootstrap_runtime_environment.py
+++ b/tests/unit/sagemaker/remote_function/runtime_environment/test_bootstrap_runtime_environment.py
@@ -199,7 +199,8 @@ def test_main_no_dependency_file(
     validate_python.assert_called_once_with(TEST_PYTHON_VERSION, TEST_JOB_CONDA_ENV)
     path_exists.assert_called_once_with(TEST_WORKSPACE_ARCHIVE_DIR_PATH)
     file_exists.assert_called_once_with(TEST_WORKSPACE_ARCHIVE_PATH)
-    get_cwd.assert_called_once()
+    # Called twice by pathlib on some platforms
+    get_cwd.assert_called()
     list_dir.assert_called_once_with(pathlib.Path(TEST_DEPENDENCIES_PATH))
     run_pre_exec_script.assert_called()
     bootstrap_runtime.assert_not_called()