diff --git a/src/sagemaker/remote_function/core/pipeline_variables.py b/src/sagemaker/remote_function/core/pipeline_variables.py index 2b18e6c50c..269ce94113 100644 --- a/src/sagemaker/remote_function/core/pipeline_variables.py +++ b/src/sagemaker/remote_function/core/pipeline_variables.py @@ -77,6 +77,17 @@ class _ExecutionVariable: name: str +@dataclass +class _S3BaseUriIdentifier: + """Identifies that the class refers to the function step s3 base uri. + + The s3_base_uri = s3_root_uri + pipeline_name. + This identifier is resolved by the SDK at function step runtime. + """ + + NAME = "S3_BASE_URI" + + @dataclass class _DelayedReturn: """Delayed return from a function.""" @@ -155,6 +166,7 @@ def __init__( hmac_key: str, parameter_resolver: _ParameterResolver, execution_variable_resolver: _ExecutionVariableResolver, + s3_base_uri: str, **settings, ): """Resolve delayed return. @@ -164,8 +176,12 @@ def __init__( hmac_key: key used to encrypt serialized and deserialized function and arguments. parameter_resolver: resolver used to pipeline parameters. execution_variable_resolver: resolver used to resolve execution variables. + s3_base_uri (str): the s3 base uri of the function step to which + the serialized artifacts will be uploaded. + The s3_base_uri = s3_root_uri + pipeline_name. **settings: settings to pass to the deserialization function. """ + self._s3_base_uri = s3_base_uri self._parameter_resolver = parameter_resolver self._execution_variable_resolver = execution_variable_resolver # different delayed returns can have the same uri, so we need to dedupe @@ -205,6 +221,8 @@ def _resolve_delayed_return_uri(self, delayed_return: _DelayedReturn): uri.append(self._parameter_resolver.resolve(component)) elif isinstance(component, _ExecutionVariable): uri.append(self._execution_variable_resolver.resolve(component)) + elif isinstance(component, _S3BaseUriIdentifier): + uri.append(self._s3_base_uri) else: uri.append(component) return s3_path_join(*uri) @@ -219,7 +237,12 @@ def _retrieve_child_item(delayed_return: _DelayedReturn, deserialized_obj: Any): def resolve_pipeline_variables( - context: Context, func_args: Tuple, func_kwargs: Dict, hmac_key: str, **settings + context: Context, + func_args: Tuple, + func_kwargs: Dict, + hmac_key: str, + s3_base_uri: str, + **settings, ): """Resolve pipeline variables. @@ -228,6 +251,8 @@ def resolve_pipeline_variables( func_args: function args. func_kwargs: function kwargs. hmac_key: key used to encrypt serialized and deserialized function and arguments. + s3_base_uri: the s3 base uri of the function step to which the serialized artifacts + will be uploaded. The s3_base_uri = s3_root_uri + pipeline_name. **settings: settings to pass to the deserialization function. """ @@ -251,6 +276,7 @@ def resolve_pipeline_variables( hmac_key=hmac_key, parameter_resolver=parameter_resolver, execution_variable_resolver=execution_variable_resolver, + s3_base_uri=s3_base_uri, **settings, ) @@ -289,11 +315,10 @@ def resolve_pipeline_variables( return resolved_func_args, resolved_func_kwargs -def convert_pipeline_variables_to_pickleable(s3_base_uri: str, func_args: Tuple, func_kwargs: Dict): +def convert_pipeline_variables_to_pickleable(func_args: Tuple, func_kwargs: Dict): """Convert pipeline variables to pickleable. Args: - s3_base_uri: s3 base uri where artifacts are stored. func_args: function args. func_kwargs: function kwargs.
""" @@ -304,11 +329,19 @@ def convert_pipeline_variables_to_pickleable(s3_base_uri: str, func_args: Tuple, from sagemaker.workflow.function_step import DelayedReturn + # Notes: + # 1. The s3_base_uri = s3_root_uri + pipeline_name, but the two may be unknown + # when defining function steps. After step-level arg serialization, + # it's hard to update the s3_base_uri in pipeline compile time. + # Thus set a placeholder: _S3BaseUriIdentifier, and let the runtime job to resolve it. + # 2. For saying s3_root_uri is unknown, it's because when defining function steps, + # the pipeline's sagemaker_session is not passed in, but the default s3_root_uri + # should be retrieved from the pipeline's sagemaker_session. def convert(arg): if isinstance(arg, DelayedReturn): return _DelayedReturn( uri=[ - s3_base_uri, + _S3BaseUriIdentifier(), ExecutionVariables.PIPELINE_EXECUTION_ID._pickleable, arg._step.name, "results", diff --git a/src/sagemaker/remote_function/core/serialization.py b/src/sagemaker/remote_function/core/serialization.py index 943e89636d..821744ee6b 100644 --- a/src/sagemaker/remote_function/core/serialization.py +++ b/src/sagemaker/remote_function/core/serialization.py @@ -161,17 +161,13 @@ def serialize_func_to_s3( Raises: SerializationError: when fail to serialize function to bytes. """ - bytes_to_upload = CloudpickleSerializer.serialize(func) - _upload_bytes_to_s3(bytes_to_upload, f"{s3_uri}/payload.pkl", s3_kms_key, sagemaker_session) - - sha256_hash = _compute_hash(bytes_to_upload, secret_key=hmac_key) - - _upload_bytes_to_s3( - _MetaData(sha256_hash).to_json(), - f"{s3_uri}/metadata.json", - s3_kms_key, - sagemaker_session, + _upload_payload_and_metadata_to_s3( + bytes_to_upload=CloudpickleSerializer.serialize(func), + hmac_key=hmac_key, + s3_uri=s3_uri, + sagemaker_session=sagemaker_session, + s3_kms_key=s3_kms_key, ) @@ -220,17 +216,12 @@ def serialize_obj_to_s3( SerializationError: when fail to serialize object to bytes. """ - bytes_to_upload = CloudpickleSerializer.serialize(obj) - - _upload_bytes_to_s3(bytes_to_upload, f"{s3_uri}/payload.pkl", s3_kms_key, sagemaker_session) - - sha256_hash = _compute_hash(bytes_to_upload, secret_key=hmac_key) - - _upload_bytes_to_s3( - _MetaData(sha256_hash).to_json(), - f"{s3_uri}/metadata.json", - s3_kms_key, - sagemaker_session, + _upload_payload_and_metadata_to_s3( + bytes_to_upload=CloudpickleSerializer.serialize(obj), + hmac_key=hmac_key, + s3_uri=s3_uri, + sagemaker_session=sagemaker_session, + s3_kms_key=s3_kms_key, ) @@ -318,8 +309,32 @@ def serialize_exception_to_s3( """ pickling_support.install() - bytes_to_upload = CloudpickleSerializer.serialize(exc) + _upload_payload_and_metadata_to_s3( + bytes_to_upload=CloudpickleSerializer.serialize(exc), + hmac_key=hmac_key, + s3_uri=s3_uri, + sagemaker_session=sagemaker_session, + s3_kms_key=s3_kms_key, + ) + +def _upload_payload_and_metadata_to_s3( + bytes_to_upload: Union[bytes, io.BytesIO], + hmac_key: str, + s3_uri: str, + sagemaker_session: Session, + s3_kms_key, +): + """Uploads serialized payload and metadata to s3. + + Args: + bytes_to_upload (bytes): Serialized bytes to upload. + hmac_key (str): Key used to calculate hmac-sha256 hash of the serialized obj. + s3_uri (str): S3 root uri to which resulting serialized artifacts will be uploaded. + sagemaker_session (sagemaker.session.Session): + The underlying Boto3 session which AWS service calls are delegated to. + s3_kms_key (str): KMS key used to encrypt artifacts uploaded to S3. 
+ """ _upload_bytes_to_s3(bytes_to_upload, f"{s3_uri}/payload.pkl", s3_kms_key, sagemaker_session) sha256_hash = _compute_hash(bytes_to_upload, secret_key=hmac_key) diff --git a/src/sagemaker/remote_function/core/stored_function.py b/src/sagemaker/remote_function/core/stored_function.py index 52b9e33936..862c67d9ee 100644 --- a/src/sagemaker/remote_function/core/stored_function.py +++ b/src/sagemaker/remote_function/core/stored_function.py @@ -14,6 +14,7 @@ from __future__ import absolute_import import os +from dataclasses import dataclass from typing import Any @@ -36,6 +37,14 @@ JSON_RESULTS_FILE = "results.json" +@dataclass +class _SerializedData: + """Data class to store serialized function and arguments""" + + func: bytes + args: bytes + + class StoredFunction: """Class representing a remote function stored in S3.""" @@ -105,6 +114,38 @@ def save(self, func, *args, **kwargs): s3_kms_key=self.s3_kms_key, ) + def save_pipeline_step_function(self, serialized_data): + """Upload serialized function and arguments to s3. + + Args: + serialized_data (_SerializedData): The serialized function + and function arguments of a function step. + """ + + logger.info( + "Uploading serialized function code to %s", + s3_path_join(self.func_upload_path, FUNCTION_FOLDER), + ) + serialization._upload_payload_and_metadata_to_s3( + bytes_to_upload=serialized_data.func, + hmac_key=self.hmac_key, + s3_uri=s3_path_join(self.func_upload_path, FUNCTION_FOLDER), + sagemaker_session=self.sagemaker_session, + s3_kms_key=self.s3_kms_key, + ) + + logger.info( + "Uploading serialized function arguments to %s", + s3_path_join(self.func_upload_path, ARGUMENTS_FOLDER), + ) + serialization._upload_payload_and_metadata_to_s3( + bytes_to_upload=serialized_data.args, + hmac_key=self.hmac_key, + s3_uri=s3_path_join(self.func_upload_path, ARGUMENTS_FOLDER), + sagemaker_session=self.sagemaker_session, + s3_kms_key=self.s3_kms_key, + ) + def load_and_invoke(self) -> Any: """Load and deserialize the function and the arguments and then execute it.""" @@ -134,6 +175,7 @@ def load_and_invoke(self) -> Any: args, kwargs, hmac_key=self.hmac_key, + s3_base_uri=self.s3_base_uri, sagemaker_session=self.sagemaker_session, ) diff --git a/src/sagemaker/remote_function/job.py b/src/sagemaker/remote_function/job.py index a3d6a1d780..c4570da463 100644 --- a/src/sagemaker/remote_function/job.py +++ b/src/sagemaker/remote_function/job.py @@ -52,11 +52,8 @@ from sagemaker.utils import name_from_base, _tmpdir, resolve_value_from_config from sagemaker.s3 import s3_path_join, S3Uploader from sagemaker import vpc_utils -from sagemaker.remote_function.core.stored_function import StoredFunction -from sagemaker.remote_function.core.pipeline_variables import ( - Context, - convert_pipeline_variables_to_pickleable, -) +from sagemaker.remote_function.core.stored_function import StoredFunction, _SerializedData +from sagemaker.remote_function.core.pipeline_variables import Context from sagemaker.remote_function.runtime_environment.runtime_environment_manager import ( RuntimeEnvironmentManager, _DependencySettings, @@ -695,6 +692,7 @@ def compile( func_args: tuple, func_kwargs: dict, run_info=None, + serialized_data: _SerializedData = None, ) -> dict: """Build the artifacts and generate the training job request.""" from sagemaker.workflow.properties import Properties @@ -732,12 +730,8 @@ def compile( func_step_s3_dir=step_compilation_context.pipeline_build_time, ), ) - converted_func_args, converted_func_kwargs = convert_pipeline_variables_to_pickleable( - 
s3_base_uri=s3_base_uri, - func_args=func_args, - func_kwargs=func_kwargs, - ) - stored_function.save(func, *converted_func_args, **converted_func_kwargs) + + stored_function.save_pipeline_step_function(serialized_data) stopping_condition = { "MaxRuntimeInSeconds": job_settings.max_runtime_in_seconds, diff --git a/src/sagemaker/workflow/function_step.py b/src/sagemaker/workflow/function_step.py index da20fd93d9..4fee8ef269 100644 --- a/src/sagemaker/workflow/function_step.py +++ b/src/sagemaker/workflow/function_step.py @@ -83,6 +83,11 @@ def __init__( func_kwargs (dict): keyword arguments of the python function. **kwargs: Additional arguments to be passed to the `step` decorator. """ + from sagemaker.remote_function.core.pipeline_variables import ( + convert_pipeline_variables_to_pickleable, + ) + from sagemaker.remote_function.core.serialization import CloudpickleSerializer + from sagemaker.remote_function.core.stored_function import _SerializedData super(_FunctionStep, self).__init__( name, StepTypeEnum.TRAINING, display_name, description, depends_on, retry_policies ) @@ -96,6 +101,21 @@ def __init__( self.__job_settings = None + ( + self._converted_func_args, + self._converted_func_kwargs, + ) = convert_pipeline_variables_to_pickleable( + func_args=self._func_args, + func_kwargs=self._func_kwargs, + ) + + self._serialized_data = _SerializedData( + func=CloudpickleSerializer.serialize(self._func), + args=CloudpickleSerializer.serialize( + (self._converted_func_args, self._converted_func_kwargs) + ), + ) + @property def func(self): """The python function to run as a pipeline step.""" @@ -185,6 +205,7 @@ def arguments(self) -> RequestType: func=self.func, func_args=self.func_args, func_kwargs=self.func_kwargs, + serialized_data=self._serialized_data, ) # Continue to pop job name if not explicitly opted-in via config request_dict = trim_request_dict(request_dict, "TrainingJobName", step_compilation_context) diff --git a/tests/integ/sagemaker/workflow/test_step_decorator.py b/tests/integ/sagemaker/workflow/test_step_decorator.py index bd4eb4c3d1..70424383f1 100644 --- a/tests/integ/sagemaker/workflow/test_step_decorator.py +++ b/tests/integ/sagemaker/workflow/test_step_decorator.py @@ -858,3 +858,61 @@ def cuberoot(x): pipeline.delete() except Exception: pass + + +def test_step_level_serialization( + sagemaker_session, role, pipeline_name, region_name, dummy_container_without_error ): + os.environ["AWS_DEFAULT_REGION"] = region_name + + _EXPECTED_STEP_A_OUTPUT = "This pipeline is a function." + _EXPECTED_STEP_B_OUTPUT = "This generates a function arg." + + step_config = dict( + role=role, + image_uri=dummy_container_without_error, + instance_type=INSTANCE_TYPE, + ) + + # This local "pipeline" function may clash with the Pipeline object + # defined below. + # However, because function and argument serialization happens at the + # step level, this clash does not break the steps.
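+    # (Each step pickles its function at step definition time, while the name +    # "pipeline" still refers to the local function and before it is rebound to +    # the Pipeline object below, so the captured reference stays valid at runtime.)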
+ def pipeline(): + return _EXPECTED_STEP_A_OUTPUT + + @step(**step_config) + def generator(): + return _EXPECTED_STEP_B_OUTPUT + + @step(**step_config) + def func_with_collision(var: str): + return f"{pipeline()} {var}" + + step_output_a = generator() + step_output_b = func_with_collision(step_output_a) + + pipeline = Pipeline( # noqa: F811 + name=pipeline_name, + steps=[step_output_b], + sagemaker_session=sagemaker_session, + ) + + try: + create_and_execute_pipeline( + pipeline=pipeline, + pipeline_name=pipeline_name, + region_name=region_name, + role=role, + no_of_steps=2, + last_step_name=get_step(step_output_b).name, + execution_parameters=dict(), + step_status="Succeeded", + step_result_type=str, + step_result_value=f"{_EXPECTED_STEP_A_OUTPUT} {_EXPECTED_STEP_B_OUTPUT}", + ) + finally: + try: + pipeline.delete() + except Exception: + pass diff --git a/tests/unit/sagemaker/remote_function/core/test_pipeline_variables.py b/tests/unit/sagemaker/remote_function/core/test_pipeline_variables.py index d936db22ee..ebe26653b8 100644 --- a/tests/unit/sagemaker/remote_function/core/test_pipeline_variables.py +++ b/tests/unit/sagemaker/remote_function/core/test_pipeline_variables.py @@ -28,6 +28,7 @@ _DelayedReturnResolver, resolve_pipeline_variables, convert_pipeline_variables_to_pickleable, + _S3BaseUriIdentifier, ) from sagemaker.workflow.parameters import ( @@ -39,6 +40,8 @@ from sagemaker.workflow.function_step import DelayedReturn from sagemaker.workflow.properties import Properties +PIPELINE_NAME = "some-pipeline" + @patch("sagemaker.remote_function.core.pipeline_variables.deserialize_obj_from_s3") def test_resolve_delayed_returns(mock_deserializer): @@ -70,6 +73,7 @@ def test_resolve_delayed_returns(mock_deserializer): _ParameterResolver(Context()), _ExecutionVariableResolver(Context()), sagemaker_session=None, + s3_base_uri=f"s3://my-bucket/{PIPELINE_NAME}", ) assert resolver.resolve(delayed_returns[0]) == 1 @@ -99,6 +103,7 @@ def test_deserializer_fails(mock_deserializer): _ParameterResolver(Context()), _ExecutionVariableResolver(Context()), sagemaker_session=None, + s3_base_uri=f"s3://my-bucket/{PIPELINE_NAME}", ) @@ -116,7 +121,12 @@ def test_no_pipeline_variables_to_resolve(mock_deserializer, func_args, func_kwa mock_deserializer.return_value = (1.0, 2.0, 3.0) resolved_args, resolved_kwargs = resolve_pipeline_variables( - Context(), func_args, func_kwargs, hmac_key="1234", sagemaker_session=None + Context(), + func_args, + func_kwargs, + hmac_key="1234", + s3_base_uri="s3://my-bucket", + sagemaker_session=None, ) assert resolved_args == func_args @@ -133,11 +143,19 @@ def test_no_pipeline_variables_to_resolve(mock_deserializer, func_args, func_kwa _ParameterFloat("parameter_2"), _ParameterBoolean("parameter_4"), _DelayedReturn( - uri=["s3://my-bucket/", _ExecutionVariable("ExecutionId"), "sub-folder-1/"], + uri=[ + _S3BaseUriIdentifier(), + _ExecutionVariable("ExecutionId"), + "sub-folder-1/", + ], reference_path=(("__getitem__", 0),), ), _DelayedReturn( - uri=["s3://my-bucket/", _ExecutionVariable("ExecutionId"), "sub-folder-1/"], + uri=[ + _S3BaseUriIdentifier(), + _ExecutionVariable("ExecutionId"), + "sub-folder-1/", + ], reference_path=(("__getitem__", 1),), ), _Properties("Steps.step_name.TrainingJobName"), @@ -154,11 +172,19 @@ def test_no_pipeline_variables_to_resolve(mock_deserializer, func_args, func_kwa "c": _ParameterFloat("parameter_2"), "d": _ParameterBoolean("parameter_4"), "e": _DelayedReturn( - uri=["s3://my-bucket/", _ExecutionVariable("ExecutionId"), 
"sub-folder-1/"], + uri=[ + _S3BaseUriIdentifier(), + _ExecutionVariable("ExecutionId"), + "sub-folder-1/", + ], reference_path=(("__getitem__", 0),), ), "f": _DelayedReturn( - uri=["s3://my-bucket/", _ExecutionVariable("ExecutionId"), "sub-folder-1/"], + uri=[ + _S3BaseUriIdentifier(), + _ExecutionVariable("ExecutionId"), + "sub-folder-1/", + ], reference_path=(("__getitem__", 1),), ), "g": _Properties("Steps.step_name.TrainingJobName"), @@ -184,6 +210,7 @@ def test_resolve_pipeline_variables( expected_resolved_args, expected_resolved_kwargs, ): + s3_base_uri = f"s3://my-bucket/{PIPELINE_NAME}" context = Context( property_references={ "Parameters.parameter_1": "1", @@ -192,20 +219,25 @@ def test_resolve_pipeline_variables( "Parameters.parameter_4": "true", "Execution.ExecutionId": "execution-id", "Steps.step_name.TrainingJobName": "a-cool-name", - } + }, ) mock_deserializer.return_value = (1.0, 2.0, 3.0) resolved_args, resolved_kwargs = resolve_pipeline_variables( - context, func_args, func_kwargs, hmac_key="1234", sagemaker_session=None + context, + func_args, + func_kwargs, + hmac_key="1234", + s3_base_uri=s3_base_uri, + sagemaker_session=None, ) assert resolved_args == expected_resolved_args assert resolved_kwargs == expected_resolved_kwargs mock_deserializer.assert_called_once_with( sagemaker_session=None, - s3_uri="s3://my-bucket/execution-id/sub-folder-1", + s3_uri=f"{s3_base_uri}/execution-id/sub-folder-1", hmac_key="1234", ) @@ -237,15 +269,13 @@ def test_convert_pipeline_variables_to_pickleable(): } converted_args, converted_kwargs = convert_pipeline_variables_to_pickleable( - "base_uri", func_args, func_kwargs + func_args, func_kwargs ) - print(converted_args) - assert converted_args == ( _DelayedReturn( uri=[ - "base_uri", + _S3BaseUriIdentifier(), _ExecutionVariable(name="PipelineExecutionId"), "parent_step", "results", @@ -264,7 +294,7 @@ def test_convert_pipeline_variables_to_pickleable(): assert converted_kwargs == { "a": _DelayedReturn( uri=[ - "base_uri", + _S3BaseUriIdentifier(), _ExecutionVariable(name="PipelineExecutionId"), "parent_step", "results", diff --git a/tests/unit/sagemaker/remote_function/core/test_stored_function.py b/tests/unit/sagemaker/remote_function/core/test_stored_function.py index 78b5a700da..bcc09cb585 100644 --- a/tests/unit/sagemaker/remote_function/core/test_stored_function.py +++ b/tests/unit/sagemaker/remote_function/core/test_stored_function.py @@ -24,10 +24,12 @@ from sagemaker.remote_function.core.stored_function import ( StoredFunction, JSON_SERIALIZED_RESULT_KEY, + _SerializedData, ) from sagemaker.remote_function.core.serialization import ( deserialize_obj_from_s3, serialize_obj_to_s3, + CloudpickleSerializer, ) from sagemaker.remote_function.core.pipeline_variables import ( Context, @@ -308,19 +310,13 @@ def test_load_and_invoke_json_serialization( @patch("sagemaker.s3.S3Uploader.upload_bytes", new=upload_bytes) @patch("sagemaker.s3.S3Downloader.read_bytes", new=read_bytes) -@patch("sagemaker.s3.S3Uploader.upload") -@patch("sagemaker.s3.S3Downloader.download") -def test_save_and_load_with_pipeline_variable( - s3_source_dir_download, s3_source_dir_upload, monkeypatch -): +@patch("sagemaker.s3.S3Uploader.upload", MagicMock()) +@patch("sagemaker.s3.S3Downloader.download", MagicMock()) +def test_save_and_load_with_pipeline_variable(monkeypatch): session = Mock() s3_base_uri = random_s3_uri() - job_settings = Mock() - job_settings.s3_root_uri = s3_base_uri - function_step = _FunctionStep( - name="func_1", display_name=None, description=None, 
job_settings=job_settings - ) + function_step = _FunctionStep(name="func_1", display_name=None, description=None) x = DelayedReturn(function_step=function_step) serialize_obj_to_s3( 3.0, session, f"{s3_base_uri}/execution-id/func_1/results", HMAC_KEY, KMS_KEY @@ -337,12 +333,11 @@ def test_save_and_load_with_pipeline_variable( "Parameters.b": "2.0", "Parameters.c": "3.0", "Execution.PipelineExecutionId": "execution-id", - } + }, ), ) func_args, func_kwargs = convert_pipeline_variables_to_pickleable( - s3_base_uri=s3_base_uri, func_args=(x,), func_kwargs={ "a": ParameterFloat("a"), @@ -350,9 +345,50 @@ def test_save_and_load_with_pipeline_variable( "c": ParameterFloat("c"), }, ) - stored_function.save(quadratic, *func_args, **func_kwargs) + + test_serialized_data = _SerializedData( + func=CloudpickleSerializer.serialize(quadratic), + args=CloudpickleSerializer.serialize((func_args, func_kwargs)), + ) + + stored_function.save_pipeline_step_function(test_serialized_data) stored_function.load_and_invoke() assert deserialize_obj_from_s3( session, s3_uri=f"{s3_base_uri}/results", hmac_key=HMAC_KEY ) == quadratic(3.0, a=1.0, b=2.0, c=3.0) + + +@patch("sagemaker.remote_function.core.serialization._upload_payload_and_metadata_to_s3") +@patch("sagemaker.remote_function.job._JobSettings") +def test_save_pipeline_step_function(mock_job_settings, upload_payload): + session = Mock() + s3_base_uri = random_s3_uri() + mock_job_settings.s3_root_uri = s3_base_uri + + stored_function = StoredFunction( + sagemaker_session=session, + s3_base_uri=s3_base_uri, + s3_kms_key=KMS_KEY, + hmac_key=HMAC_KEY, + context=Context( + step_name="step_name", + execution_id="execution_id", + ), + ) + + func_args, func_kwargs = convert_pipeline_variables_to_pickleable( + func_args=(1,), + func_kwargs={ + "a": 2, + "b": 3, + }, + ) + + test_serialized_data = _SerializedData( + func=CloudpickleSerializer.serialize(quadratic), + args=CloudpickleSerializer.serialize((func_args, func_kwargs)), + ) + stored_function.save_pipeline_step_function(test_serialized_data) + + assert upload_payload.call_count == 2 diff --git a/tests/unit/sagemaker/remote_function/test_job.py b/tests/unit/sagemaker/remote_function/test_job.py index f025276634..1884486f8b 100644 --- a/tests/unit/sagemaker/remote_function/test_job.py +++ b/tests/unit/sagemaker/remote_function/test_job.py @@ -17,9 +17,11 @@ import pytest from mock import patch, Mock, ANY, mock_open +from mock.mock import MagicMock from sagemaker.config import load_sagemaker_config from sagemaker.remote_function.checkpoint_location import CheckpointLocation +from sagemaker.remote_function.core.stored_function import _SerializedData from sagemaker.session_settings import SessionSettings from sagemaker.remote_function.spark_config import SparkConfig @@ -149,6 +151,10 @@ def job_function_with_checkpoint(a, checkpoint_1=None, *, b, checkpoint_2=None): return a + b +def serialized_data(): + return _SerializedData(func=b"serialized_func", args=b"serialized_args") + + @patch("secrets.token_hex", return_value=HMAC_KEY) @patch("sagemaker.remote_function.job.Session", return_value=mock_session()) @patch("sagemaker.remote_function.job.get_execution_role", return_value=DEFAULT_ROLE_ARN) @@ -731,7 +737,7 @@ def test_start_with_complete_job_settings( @patch("sagemaker.workflow.utilities._pipeline_config", MOCKED_PIPELINE_CONFIG) -@patch("secrets.token_hex", return_value=HMAC_KEY) +@patch("secrets.token_hex", MagicMock(return_value=HMAC_KEY)) @patch( 
"sagemaker.remote_function.job._prepare_dependencies_and_pre_execution_scripts", return_value="some_s3_uri", @@ -750,11 +756,9 @@ def test_get_train_args_under_pipeline_context( mock_bootstrap_scripts_upload, mock_user_workspace_upload, mock_user_dependencies_upload, - secret_token, ): from sagemaker.workflow.parameters import ParameterInteger - from sagemaker.remote_function.core.pipeline_variables import _ParameterInteger mock_stored_function = Mock() mock_stored_function_ctr.return_value = mock_stored_function @@ -776,6 +780,7 @@ def test_get_train_args_under_pipeline_context( security_group_ids=["sg"], ) + mocked_serialized_data = serialized_data() s3_base_uri = f"{S3_URI}/{TEST_PIPELINE_NAME}" train_args = _Job.compile( job_settings=job_settings, @@ -784,6 +789,7 @@ def test_get_train_args_under_pipeline_context( func=job_function, func_args=(1, ParameterInteger(name="b", default_value=2)), func_kwargs={"c": 3, "d": ParameterInteger(name="d", default_value=4)}, + serialized_data=mocked_serialized_data, ) mock_stored_function_ctr.assert_called_once_with( @@ -796,11 +802,7 @@ def test_get_train_args_under_pipeline_context( func_step_s3_dir=MOCKED_PIPELINE_CONFIG.pipeline_build_time, ), ) - mock_stored_function.save.assert_called_once_with( - job_function, - *(1, _ParameterInteger(name="b")), - **{"c": 3, "d": _ParameterInteger(name="d")}, - ) + mock_stored_function.save_pipeline_step_function.assert_called_once_with(mocked_serialized_data) local_dependencies_path = mock_runtime_manager().snapshot() mock_python_version = mock_runtime_manager()._current_python_version() diff --git a/tests/unit/sagemaker/workflow/test_function_step.py b/tests/unit/sagemaker/workflow/test_function_step.py index 5e08d7005b..888635ae02 100644 --- a/tests/unit/sagemaker/workflow/test_function_step.py +++ b/tests/unit/sagemaker/workflow/test_function_step.py @@ -19,6 +19,8 @@ from mock import patch, Mock, ANY from typing import List, Tuple +from mock.mock import MagicMock + from sagemaker.workflow.parameters import ParameterInteger from sagemaker.workflow.execution_variables import ExecutionVariables from sagemaker.workflow.functions import Join @@ -105,6 +107,8 @@ def sum(a, b, c, d): assert function_step._job_settings is not None assert mock_job_settings.call_args[1]["image_uri"] == "test_image_uri" + assert function_step._serialized_data.func is not None + assert function_step._serialized_data.args is not None @patch("sagemaker.workflow.utilities._pipeline_config", MOCKED_PIPELINE_CONFIG) @@ -127,6 +131,8 @@ def sum(a, b, c, d): assert function_step.description == "Returns sum of numbers" assert function_step.retry_policies == [] assert function_step.depends_on == [] + assert function_step._serialized_data.func is not None + assert function_step._serialized_data.args is not None @patch("sagemaker.workflow.utilities._pipeline_config", MOCKED_PIPELINE_CONFIG) @@ -261,11 +267,13 @@ def sum(a, b): func=ANY, func_args=(2, 3), func_kwargs={}, + serialized_data=step_output._step._serialized_data, ) mock_job_settings_ctr.assert_called_once() +@patch("sagemaker.remote_function.job._JobSettings", MagicMock()) @pytest.mark.parametrize( "type_hint", [ @@ -278,8 +286,7 @@ def sum(a, b): Tuple[int, ...], ], ) -@patch("sagemaker.remote_function.job._JobSettings") -def test_step_function_with_sequence_return_value(mock_job_settings, type_hint): +def test_step_function_with_sequence_return_value(type_hint): @step def func() -> type_hint: return 1, 2, 3 @@ -366,8 +373,9 @@ def func(): pass 
-@patch("sagemaker.remote_function.job._JobSettings") -def test_step_function_take_in_delayed_return_as_positional_arguments(mock_job_settings): +@patch("sagemaker.remote_function.core.serialization.CloudpickleSerializer.serialize", MagicMock()) +@patch("sagemaker.remote_function.job._JobSettings", MagicMock()) +def test_step_function_take_in_delayed_return_as_positional_arguments(): @step def func_1() -> Tuple: return 1, 2, 3 @@ -390,8 +398,9 @@ def func_2(a, b, c, param_1, param_2): get_step(func_2_output).depends_on = [] -@patch("sagemaker.remote_function.job._JobSettings") -def test_step_function_take_in_delayed_return_as_keyword_arguments(mock_job_settings): +@patch("sagemaker.remote_function.core.serialization.CloudpickleSerializer.serialize", MagicMock()) +@patch("sagemaker.remote_function.job._JobSettings", MagicMock()) +def test_step_function_take_in_delayed_return_as_keyword_arguments(): @step def func_1() -> Tuple: return 1, 2, 3 @@ -414,8 +423,9 @@ def func_2(a, b, c, param_1, param_2): get_step(func_2_output).depends_on = [] -@patch("sagemaker.remote_function.job._JobSettings") -def test_delayed_returns_in_nested_object_are_ignored(mock_job_settings): +@patch("sagemaker.remote_function.core.serialization.CloudpickleSerializer.serialize", MagicMock()) +@patch("sagemaker.remote_function.job._JobSettings", MagicMock()) +def test_delayed_returns_in_nested_object_are_ignored(): @step def func_1() -> Tuple: return 1, 2, 3 @@ -438,8 +448,9 @@ def func_2(data, param_1, param_2): assert get_step(func_2_output).depends_on == [] -@patch("sagemaker.remote_function.job._JobSettings") -def test_unsupported_pipeline_variables_as_function_arguments(mock_job_settings): +@patch("sagemaker.remote_function.core.serialization.CloudpickleSerializer.serialize", MagicMock()) +@patch("sagemaker.remote_function.job._JobSettings", MagicMock()) +def test_unsupported_pipeline_variables_as_function_arguments(): @step def func_1() -> Tuple: return 1, 2, 3 @@ -461,8 +472,9 @@ def func_2(a, b, c, param_1, param_2): assert "Properties attribute is not supported for _FunctionStep" in str(e.value) -@patch("sagemaker.remote_function.job._JobSettings") -def test_both_data_and_execution_dependency_between_steps(mock_job_settings): +@patch("sagemaker.remote_function.core.serialization.CloudpickleSerializer.serialize", MagicMock()) +@patch("sagemaker.remote_function.job._JobSettings", MagicMock()) +def test_both_data_and_execution_dependency_between_steps(): @step def func_0() -> None: pass @@ -491,8 +503,8 @@ def func_2(a, b, c, param_1, param_2): get_step(func_2_output).depends_on = [] -@patch("sagemaker.remote_function.job._JobSettings") -def test_disable_deepcopy_of_delayed_return(mock_job_settings): +@patch("sagemaker.remote_function.job._JobSettings", MagicMock()) +def test_disable_deepcopy_of_delayed_return(): @step def func(): return 1