diff --git a/Makefile b/Makefile index aaf1736258..3ab83da892 100644 --- a/Makefile +++ b/Makefile @@ -169,17 +169,17 @@ test: # help: test-verbose - Run all tests verbosely .PHONY: test-verbose test-verbose: - @python -m pytest -vv --ignore=tests/full_wlm/ + @python -m pytest -vv --ignore=tests/full_wlm/ --ignore=tests/dragon # help: test-debug - Run all tests with debug output .PHONY: test-debug test-debug: - @SMARTSIM_LOG_LEVEL=developer python -m pytest -s -o log_cli=true -vv --ignore=tests/full_wlm/ + @SMARTSIM_LOG_LEVEL=developer python -m pytest -s -o log_cli=true -vv --ignore=tests/full_wlm/ --ignore=tests/dragon # help: test-cov - Run all tests with coverage .PHONY: test-cov test-cov: - @python -m pytest -vv --cov=./smartsim --cov-config=${COV_FILE} --ignore=tests/full_wlm/ + @python -m pytest -vv --cov=./smartsim --cov-config=${COV_FILE} --ignore=tests/full_wlm/ --ignore=tests/dragon # help: test-full - Run all WLM tests with Python coverage (full test suite) diff --git a/doc/changelog.md b/doc/changelog.md index ee41fabf88..28de6e8f95 100644 --- a/doc/changelog.md +++ b/doc/changelog.md @@ -14,6 +14,7 @@ Jump to: Description - Add TorchWorker first implementation and mock inference app example +- Add error handling in Worker Manager pipeline - Add EnvironmentConfigLoader for ML Worker Manager - Add Model schema with model metadata included - Removed device from schemas, MessageHandler and tests diff --git a/smartsim/_core/mli/infrastructure/control/workermanager.py b/smartsim/_core/mli/infrastructure/control/workermanager.py index 8c06351fb5..8e3ed3fb4c 100644 --- a/smartsim/_core/mli/infrastructure/control/workermanager.py +++ b/smartsim/_core/mli/infrastructure/control/workermanager.py @@ -51,13 +51,13 @@ MachineLearningWorkerBase, ) from ...message_handler import MessageHandler -from ...mli_schemas.response.response_capnp import Response +from ...mli_schemas.response.response_capnp import Response, ResponseBuilder if t.TYPE_CHECKING: from dragon.fli import FLInterface from smartsim._core.mli.mli_schemas.model.model_capnp import Model - from smartsim._core.mli.mli_schemas.response.response_capnp import StatusEnum + from smartsim._core.mli.mli_schemas.response.response_capnp import Status logger = get_logger(__name__) @@ -98,6 +98,7 @@ def deserialize_message( input_bytes: t.Optional[t.List[bytes]] = ( None # these will really be tensors already ) + output_keys: t.Optional[t.List[str]] = None input_meta: t.List[t.Any] = [] @@ -107,22 +108,26 @@ def deserialize_message( input_bytes = [data.blob for data in request.input.data] input_meta = [data.tensorDescriptor for data in request.input.data] + if request.output: + output_keys = [tensor_key.key for tensor_key in request.output] + inference_request = InferenceRequest( model_key=model_key, callback=comm_channel, raw_inputs=input_bytes, - input_meta=input_meta, input_keys=input_keys, + input_meta=input_meta, + output_keys=output_keys, raw_model=model_bytes, batch_size=0, ) return inference_request -def build_failure_reply(status: "StatusEnum", message: str) -> Response: +def build_failure_reply(status: "Status", message: str) -> ResponseBuilder: return MessageHandler.build_response( - status=status, # todo: need to indicate correct status - message=message, # todo: decide what these will be + status=status, + message=message, result=[], custom_attributes=None, ) @@ -154,17 +159,39 @@ def prepare_outputs(reply: InferenceReply) -> t.List[t.Any]: return prepared_outputs -def build_reply(reply: InferenceReply) -> Response: +def 
build_reply(reply: InferenceReply) -> ResponseBuilder: results = prepare_outputs(reply) return MessageHandler.build_response( - status="complete", - message="success", + status=reply.status_enum, + message=reply.message, result=results, custom_attributes=None, ) +def exception_handler( + exc: Exception, reply_channel: t.Optional[CommChannelBase], failure_message: str +) -> None: + """ + Logs exceptions and sends a failure response. + + :param exc: The exception to be logged + :param reply_channel: The channel used to send replies + :param failure_message: Failure message to log and send back + """ + logger.exception( + f"{failure_message}\n" + f"Exception type: {type(exc).__name__}\n" + f"Exception message: {str(exc)}" + ) + serialized_resp = MessageHandler.serialize_response( + build_failure_reply("fail", failure_message) + ) + if reply_channel: + reply_channel.send(serialized_resp) + + class WorkerManager(Service): """An implementation of a service managing distribution of tasks to machine learning workers""" @@ -258,96 +285,147 @@ def _on_iteration(self) -> None: timings.append(time.perf_counter() - interm) # timing interm = time.perf_counter() # timing + reply = InferenceReply() + if not request.raw_model: if request.model_key is None: - # A valid request should never get here. - raise ValueError("Could not read model key") + exception_handler( + ValueError("Could not find model key or model"), + request.callback, + "Could not find model key or model.", + ) + return if request.model_key in self._cached_models: timings.append(time.perf_counter() - interm) # timing interm = time.perf_counter() # timing model_result = LoadModelResult(self._cached_models[request.model_key]) else: - fetch_model_result = None - while True: - try: - interm = time.perf_counter() # timing - fetch_model_result = self._worker.fetch_model( - request, self._feature_store - ) - except KeyError: - time.sleep(0.1) - else: - break - - if fetch_model_result is None: - raise SmartSimError("Could not retrieve model from feature store") timings.append(time.perf_counter() - interm) # timing interm = time.perf_counter() # timing + try: + fetch_model_result = self._worker.fetch_model( + request, self._feature_store + ) + except Exception as e: + exception_handler( + e, request.callback, "Failed while fetching the model." + ) + return + + timings.append(time.perf_counter() - interm) # timing + interm = time.perf_counter() # timing + try: + model_result = self._worker.load_model( + request, + fetch_result=fetch_model_result, + device=self._device, + ) + self._cached_models[request.model_key] = model_result.model + except Exception as e: + exception_handler( + e, request.callback, "Failed while loading the model." + ) + return + + else: + timings.append(time.perf_counter() - interm) # timing + interm = time.perf_counter() # timing + try: + fetch_model_result = self._worker.fetch_model( + request, self._feature_store + ) + except Exception as e: + exception_handler( + e, request.callback, "Failed while fetching the model." 
+ ) + return + + timings.append(time.perf_counter() - interm) # timing + interm = time.perf_counter() # timing + try: model_result = self._worker.load_model( - request, fetch_model_result, self._device + request, fetch_result=fetch_model_result, device=self._device ) - self._cached_models[request.model_key] = model_result.model - else: - fetch_model_result = self._worker.fetch_model(request, None) - model_result = self._worker.load_model( - request, fetch_result=fetch_model_result, device=self._device - ) + except Exception as e: + exception_handler( + e, request.callback, "Failed while loading the model." + ) + return timings.append(time.perf_counter() - interm) # timing interm = time.perf_counter() # timing - fetch_input_result = self._worker.fetch_inputs(request, self._feature_store) + try: + fetch_input_result = self._worker.fetch_inputs(request, self._feature_store) + except Exception as e: + exception_handler(e, request.callback, "Failed while fetching the inputs.") + return timings.append(time.perf_counter() - interm) # timing interm = time.perf_counter() # timing - transformed_input = self._worker.transform_input( - request, fetch_input_result, self._device - ) + try: + transformed_input = self._worker.transform_input( + request, fetch_input_result, self._device + ) + except Exception as e: + exception_handler( + e, request.callback, "Failed while transforming the input." + ) + return timings.append(time.perf_counter() - interm) # timing interm = time.perf_counter() # timing - - reply = InferenceReply() - try: execute_result = self._worker.execute( request, model_result, transformed_input ) + except Exception as e: + exception_handler(e, request.callback, "Failed while executing.") + return - timings.append(time.perf_counter() - interm) # timing - interm = time.perf_counter() # timing + timings.append(time.perf_counter() - interm) # timing + interm = time.perf_counter() # timing + try: transformed_output = self._worker.transform_output( request, execute_result, self._device ) + except Exception as e: + exception_handler( + e, request.callback, "Failed while transforming the output." + ) + return - timings.append(time.perf_counter() - interm) # timing - interm = time.perf_counter() # timing - if request.output_keys: + timings.append(time.perf_counter() - interm) # timing + interm = time.perf_counter() # timing + if request.output_keys: + try: reply.output_keys = self._worker.place_output( - request, transformed_output, self._feature_store + request, + transformed_output, + self._feature_store, ) - else: - reply.outputs = transformed_output.outputs - except Exception: - logger.exception("Error executing worker") - reply.failed = True + except Exception as e: + exception_handler( + e, request.callback, "Failed while placing the output." 
+ ) + return + else: + reply.outputs = transformed_output.outputs timings.append(time.perf_counter() - interm) # timing interm = time.perf_counter() # timing - if reply.failed: - response = build_failure_reply("fail", "failure-occurred") + if reply.outputs is None or not reply.outputs: + response = build_failure_reply("fail", "Outputs not found.") else: - if reply.outputs is None or not reply.outputs: - response = build_failure_reply("fail", "no-results") - + reply.status_enum = "complete" + reply.message = "Success" response = build_reply(reply) timings.append(time.perf_counter() - interm) # timing interm = time.perf_counter() # timing - # serialized = self._worker.serialize_reply(request, transformed_output) - serialized_resp = MessageHandler.serialize_response(response) # type: ignore + serialized_resp = MessageHandler.serialize_response(response) timings.append(time.perf_counter() - interm) # timing interm = time.perf_counter() # timing diff --git a/smartsim/_core/mli/infrastructure/worker/worker.py b/smartsim/_core/mli/infrastructure/worker/worker.py index 900a8241de..dd874abe39 100644 --- a/smartsim/_core/mli/infrastructure/worker/worker.py +++ b/smartsim/_core/mli/infrastructure/worker/worker.py @@ -33,6 +33,9 @@ from ...infrastructure.storage.featurestore import FeatureStore from ...mli_schemas.model.model_capnp import Model +if t.TYPE_CHECKING: + from smartsim._core.mli.mli_schemas.response.response_capnp import Status + logger = get_logger(__name__) @@ -70,12 +73,14 @@ def __init__( self, outputs: t.Optional[t.Collection[t.Any]] = None, output_keys: t.Optional[t.Collection[str]] = None, - failed: bool = False, + status_enum: "Status" = "running", + message: str = "In progress", ) -> None: """Initialize the object""" self.outputs: t.Collection[t.Any] = outputs or [] self.output_keys: t.Collection[t.Optional[str]] = output_keys or [] - self.failed = failed + self.status_enum = status_enum + self.message = message class LoadModelResult: diff --git a/smartsim/_core/mli/message_handler.py b/smartsim/_core/mli/message_handler.py index bcf1cfdf14..4fe2bef3a7 100644 --- a/smartsim/_core/mli/message_handler.py +++ b/smartsim/_core/mli/message_handler.py @@ -360,7 +360,7 @@ def build_request( request_attributes_capnp.TensorFlowRequestAttributes, None, ], - ) -> request_capnp.Request: + ) -> request_capnp.RequestBuilder: """ Builds the request message. @@ -405,7 +405,7 @@ def deserialize_request(request_bytes: bytes) -> request_capnp.Request: @staticmethod def _assign_status( - response: response_capnp.Response, status: "response_capnp.StatusEnum" + response: response_capnp.Response, status: "response_capnp.Status" ) -> None: """ Assigns a status to the supplied response. @@ -498,7 +498,7 @@ def _assign_custom_response_attributes( @staticmethod def build_response( - status: "response_capnp.StatusEnum", + status: "response_capnp.Status", message: str, result: t.Union[ t.List[tensor_capnp.Tensor], t.List[data_references_capnp.TensorKey] @@ -508,7 +508,7 @@ def build_response( response_attributes_capnp.TensorFlowResponseAttributes, None, ], - ) -> response_capnp.Response: + ) -> response_capnp.ResponseBuilder: """ Builds the response message. 
diff --git a/smartsim/_core/mli/mli_schemas/response/response.capnp b/smartsim/_core/mli/mli_schemas/response/response.capnp index 67375b5a97..83aa05a41b 100644 --- a/smartsim/_core/mli/mli_schemas/response/response.capnp +++ b/smartsim/_core/mli/mli_schemas/response/response.capnp @@ -30,14 +30,15 @@ using Tensors = import "../tensor/tensor.capnp"; using ResponseAttributes = import "response_attributes/response_attributes.capnp"; using DataRef = import "../data/data_references.capnp"; -enum StatusEnum { +enum Status { complete @0; fail @1; timeout @2; + running @3; } struct Response { - status @0 :StatusEnum; + status @0 :Status; message @1 :Text; result :union { keys @2 :List(DataRef.TensorKey); diff --git a/smartsim/_core/mli/mli_schemas/response/response_capnp.pyi b/smartsim/_core/mli/mli_schemas/response/response_capnp.pyi index f6d7f8444e..f19bdefe04 100644 --- a/smartsim/_core/mli/mli_schemas/response/response_capnp.pyi +++ b/smartsim/_core/mli/mli_schemas/response/response_capnp.pyi @@ -45,7 +45,7 @@ from .response_attributes.response_attributes_capnp import ( TorchResponseAttributesReader, ) -StatusEnum = Literal["complete", "fail", "timeout"] +Status = Literal["complete", "fail", "timeout", "running"] class Response: class Result: @@ -150,7 +150,7 @@ class Response: def write(file: BufferedWriter) -> None: ... @staticmethod def write_packed(file: BufferedWriter) -> None: ... - status: StatusEnum + status: Status message: str result: Response.Result | Response.ResultBuilder | Response.ResultReader customAttributes: ( diff --git a/tests/dragon/test_error_handling.py b/tests/dragon/test_error_handling.py new file mode 100644 index 0000000000..151bdd2fcc --- /dev/null +++ b/tests/dragon/test_error_handling.py @@ -0,0 +1,270 @@ +# BSD 2-Clause License +# +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ +import base64 +import pickle +from unittest.mock import MagicMock + +import pytest + +dragon = pytest.importorskip("dragon") + +import dragon.utils as du +from dragon.channels import Channel +from dragon.data.ddict.ddict import DDict +from dragon.fli import FLInterface + +from smartsim._core.mli.infrastructure.control.workermanager import ( + WorkerManager, + exception_handler, +) +from smartsim._core.mli.infrastructure.environmentloader import EnvironmentConfigLoader +from smartsim._core.mli.infrastructure.storage.dragonfeaturestore import ( + DragonFeatureStore, +) +from smartsim._core.mli.infrastructure.worker.worker import ( + ExecuteResult, + FetchInputResult, + FetchModelResult, + InferenceReply, + LoadModelResult, + TransformInputResult, + TransformOutputResult, +) +from smartsim._core.mli.message_handler import MessageHandler + +from .utils.channel import FileSystemCommChannel +from .utils.worker import IntegratedTorchWorker + +# The tests in this file belong to the dragon group +pytestmark = pytest.mark.dragon + + +@pytest.fixture +def setup_worker_manager_model_bytes(test_dir, monkeypatch: pytest.MonkeyPatch): + integrated_worker = IntegratedTorchWorker() + + chan = Channel.make_process_local() + queue = FLInterface(main_ch=chan) + monkeypatch.setenv("SSQueue", du.B64.bytes_to_str(queue.serialize())) + storage = DDict() + feature_store = DragonFeatureStore(storage) + monkeypatch.setenv( + "SSFeatureStore", base64.b64encode(pickle.dumps(feature_store)).decode("utf-8") + ) + + worker_manager = WorkerManager( + EnvironmentConfigLoader(), + integrated_worker, + as_service=False, + cooldown=3, + comm_channel_type=FileSystemCommChannel, + ) + + tensor_key = MessageHandler.build_tensor_key("key") + model = MessageHandler.build_model(b"model", "model name", "v 0.0.1") + request = MessageHandler.build_request( + test_dir, model, [tensor_key], [tensor_key], [], None + ) + ser_request = MessageHandler.serialize_request(request) + worker_manager._task_queue.send(ser_request) + + return worker_manager, integrated_worker + + +@pytest.fixture +def setup_worker_manager_model_key(test_dir, monkeypatch: pytest.MonkeyPatch): + integrated_worker = IntegratedTorchWorker() + + chan = Channel.make_process_local() + queue = FLInterface(main_ch=chan) + monkeypatch.setenv("SSQueue", du.B64.bytes_to_str(queue.serialize())) + storage = DDict() + feature_store = DragonFeatureStore(storage) + monkeypatch.setenv( + "SSFeatureStore", base64.b64encode(pickle.dumps(feature_store)).decode("utf-8") + ) + + worker_manager = WorkerManager( + EnvironmentConfigLoader(), + integrated_worker, + as_service=False, + cooldown=3, + comm_channel_type=FileSystemCommChannel, + ) + + tensor_key = MessageHandler.build_tensor_key("key") + model_key = MessageHandler.build_model_key("model key") + request = MessageHandler.build_request( + test_dir, model_key, [tensor_key], [tensor_key], [], None + ) + ser_request = MessageHandler.serialize_request(request) + worker_manager._task_queue.send(ser_request) + + return worker_manager, integrated_worker + + +def mock_pipeline_stage(monkeypatch: pytest.MonkeyPatch, integrated_worker, stage): + def mock_stage(*args, **kwargs): + raise ValueError(f"Simulated error in {stage}") + + monkeypatch.setattr(integrated_worker, stage, mock_stage) + mock_reply_fn = MagicMock() + monkeypatch.setattr( + "smartsim._core.mli.infrastructure.control.workermanager.build_failure_reply", + mock_reply_fn, + ) + + def mock_exception_handler(exc, reply_channel, failure_message): + return exception_handler(exc, 
None, failure_message) + + monkeypatch.setattr( + "smartsim._core.mli.infrastructure.control.workermanager.exception_handler", + mock_exception_handler, + ) + + return mock_reply_fn + + +@pytest.mark.parametrize( + "setup_worker_manager", + [ + pytest.param("setup_worker_manager_model_bytes"), + pytest.param("setup_worker_manager_model_key"), + ], +) +@pytest.mark.parametrize( + "stage, error_message", + [ + pytest.param( + "fetch_model", "Failed while fetching the model.", id="fetch model" + ), + pytest.param("load_model", "Failed while loading the model.", id="load model"), + pytest.param( + "fetch_inputs", "Failed while fetching the inputs.", id="fetch inputs" + ), + pytest.param( + "transform_input", + "Failed while transforming the input.", + id="transform inputs", + ), + pytest.param("execute", "Failed while executing.", id="execute"), + pytest.param( + "transform_output", + "Failed while transforming the output.", + id="transform output", + ), + pytest.param( + "place_output", "Failed while placing the output.", id="place output" + ), + ], +) +def test_pipeline_stage_errors_handled( + request, + setup_worker_manager, + monkeypatch: pytest.MonkeyPatch, + stage: str, + error_message: str, +): + """Ensures that the worker manager does not crash after a failure in various pipeline stages""" + worker_manager, integrated_worker = request.getfixturevalue(setup_worker_manager) + mock_reply_fn = mock_pipeline_stage(monkeypatch, integrated_worker, stage) + + if stage not in ["fetch_model"]: + monkeypatch.setattr( + integrated_worker, + "fetch_model", + MagicMock(return_value=FetchModelResult(b"result_bytes")), + ) + + if stage not in ["fetch_model", "load_model"]: + monkeypatch.setattr( + integrated_worker, + "load_model", + MagicMock(return_value=LoadModelResult(b"result_bytes")), + ) + if stage not in ["fetch_model", "load_model", "fetch_inputs"]: + monkeypatch.setattr( + integrated_worker, + "fetch_inputs", + MagicMock(return_value=FetchInputResult([b"result_bytes"], None)), + ) + if stage not in ["fetch_model", "load_model", "fetch_inputs", "transform_input"]: + monkeypatch.setattr( + integrated_worker, + "transform_input", + MagicMock(return_value=TransformInputResult(b"result_bytes")), + ) + if stage not in [ + "fetch_model", + "load_model", + "fetch_inputs", + "transform_input", + "execute", + ]: + monkeypatch.setattr( + integrated_worker, + "execute", + MagicMock(return_value=ExecuteResult(b"result_bytes")), + ) + if stage not in [ + "fetch_model", + "load_model", + "fetch_inputs", + "transform_input", + "execute", + "transform_output", + ]: + monkeypatch.setattr( + integrated_worker, + "transform_output", + MagicMock( + return_value=TransformOutputResult(b"result", [], "c", "float32") + ), + ) + + worker_manager._on_iteration() + + mock_reply_fn.assert_called_once() + mock_reply_fn.assert_called_with("fail", error_message) + + +def test_exception_handling_helper(monkeypatch: pytest.MonkeyPatch): + """Ensures that the worker manager does not crash after a failure in the + execute pipeline stage""" + reply = InferenceReply() + + mock_reply_fn = MagicMock() + monkeypatch.setattr( + "smartsim._core.mli.infrastructure.control.workermanager.build_failure_reply", + mock_reply_fn, + ) + + test_exception = ValueError("Test ValueError") + exception_handler(test_exception, None, "Failure while fetching the model.") + + mock_reply_fn.assert_called_once() + mock_reply_fn.assert_called_with("fail", "Failure while fetching the model.") diff --git a/tests/dragon/test_reply_building.py 
b/tests/dragon/test_reply_building.py new file mode 100644 index 0000000000..d1c4d226bb --- /dev/null +++ b/tests/dragon/test_reply_building.py @@ -0,0 +1,91 @@ +# BSD 2-Clause License +# +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import typing as t + +import pytest + +dragon = pytest.importorskip("dragon") + +from smartsim._core.mli.infrastructure.control.workermanager import ( + build_failure_reply, + build_reply, +) +from smartsim._core.mli.infrastructure.worker.worker import InferenceReply + +if t.TYPE_CHECKING: + from smartsim._core.mli.mli_schemas.response.response_capnp import Status + +# The tests in this file belong to the dragon group +pytestmark = pytest.mark.dragon + + +@pytest.mark.parametrize( + "status, message", + [ + pytest.param("timeout", "Worker timed out", id="timeout"), + pytest.param("fail", "Failed while executing", id="fail"), + ], +) +def test_build_failure_reply(status: "Status", message: str): + "Ensures failure replies can be built successfully" + response = build_failure_reply(status, message) + assert response.status == status + assert response.message == message + + +def test_build_failure_reply_fails(): + "Ensures ValueError is raised if a Status Enum is not used" + with pytest.raises(ValueError) as ex: + response = build_failure_reply("not a status enum", "message") + + assert "Error assigning status to response" in ex.value.args[0] + + +@pytest.mark.parametrize( + "status, message", + [ + pytest.param("complete", "Success", id="complete"), + ], +) +def test_build_reply(status: "Status", message: str): + "Ensures replies can be built successfully" + reply = InferenceReply() + reply.status_enum = status + reply.message = message + response = build_reply(reply) + assert response.status == status + assert response.message == message + + +def test_build_reply_fails(): + "Ensures ValueError is raised if a Status Enum is not used" + with pytest.raises(ValueError) as ex: + reply = InferenceReply() + reply.status_enum = "not a status enum" + response = build_reply(reply) + + assert "Error assigning status to response" in ex.value.args[0] diff --git a/tests/dragon/utils/channel.py b/tests/dragon/utils/channel.py new file mode 100644 index 
0000000000..df76c484b5
--- /dev/null
+++ b/tests/dragon/utils/channel.py
@@ -0,0 +1,64 @@
+# BSD 2-Clause License
+#
+# Copyright (c) 2021-2024, Hewlett Packard Enterprise
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# 1. Redistributions of source code must retain the above copyright notice, this
+# list of conditions and the following disclaimer.
+#
+# 2. Redistributions in binary form must reproduce the above copyright notice,
+# this list of conditions and the following disclaimer in the documentation
+# and/or other materials provided with the distribution.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+import pathlib
+import typing as t
+
+from smartsim._core.mli.comm.channel.channel import CommChannelBase
+from smartsim.log import get_logger
+
+logger = get_logger(__name__)
+
+
+class FileSystemCommChannel(CommChannelBase):
+    """Passes messages by writing to a file"""
+
+    def __init__(self, key: t.Union[bytes, pathlib.Path]) -> None:
+        """Initialize the FileSystemCommChannel instance"""
+        if not isinstance(key, bytes):
+            super().__init__(key.as_posix().encode("utf-8"))
+            self._file_path = key
+        else:
+            super().__init__(key)
+            self._file_path = pathlib.Path(key.decode("utf-8"))
+
+        if not self._file_path.parent.exists():
+            self._file_path.parent.mkdir(parents=True)
+
+        self._file_path.touch()
+
+    def send(self, value: bytes) -> None:
+        """Send a message through the underlying communication channel
+        :param value: The value to send"""
+        logger.debug(
+            f"Channel {self.descriptor.decode('utf-8')} sending message to {self._file_path}"
+        )
+        self._file_path.write_bytes(value)
+
+    def recv(self) -> bytes:
+        """Receive a message through the underlying communication channel
+        :returns: the received message"""
+        ...
diff --git a/tests/dragon/utils/worker.py b/tests/dragon/utils/worker.py
new file mode 100644
index 0000000000..b1de280185
--- /dev/null
+++ b/tests/dragon/utils/worker.py
@@ -0,0 +1,128 @@
+# BSD 2-Clause License
+#
+# Copyright (c) 2021-2024, Hewlett Packard Enterprise
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# 1. Redistributions of source code must retain the above copyright notice, this
+# list of conditions and the following disclaimer.
+#
+# 2. Redistributions in binary form must reproduce the above copyright notice,
+# this list of conditions and the following disclaimer in the documentation
+# and/or other materials provided with the distribution.
+# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import io +import typing as t + +import torch + +import smartsim._core.mli.infrastructure.worker.worker as mliw +import smartsim.error as sse +from smartsim.log import get_logger + +logger = get_logger(__name__) + + +class IntegratedTorchWorker(mliw.MachineLearningWorkerBase): + """A minimum implementation of a worker that executes a PyTorch model""" + + # @staticmethod + # def deserialize(request: InferenceRequest) -> t.List[t.Any]: + # # request.input_meta + # # request.raw_inputs + # return request + + @staticmethod + def load_model( + request: mliw.InferenceRequest, fetch_result: mliw.FetchModelResult + ) -> mliw.LoadModelResult: + model_bytes = fetch_result.model_bytes or request.raw_model + if not model_bytes: + raise ValueError("Unable to load model without reference object") + + model: torch.nn.Module = torch.load(io.BytesIO(model_bytes)) + result = mliw.LoadModelResult(model) + return result + + @staticmethod + def transform_input( + request: mliw.InferenceRequest, + fetch_result: mliw.FetchInputResult, + ) -> mliw.TransformInputResult: + # extra metadata for assembly can be found in request.input_meta + raw_inputs = request.raw_inputs or fetch_result.inputs + + result: t.List[torch.Tensor] = [] + # should this happen here? + # consider - fortran to c data layout + # is there an intermediate representation before really doing torch.load? + if raw_inputs: + result = [torch.load(io.BytesIO(item)) for item in raw_inputs] + + return mliw.TransformInputResult(result) + + @staticmethod + def execute( + request: mliw.InferenceRequest, + load_result: mliw.LoadModelResult, + transform_result: mliw.TransformInputResult, + ) -> mliw.ExecuteResult: + if not load_result.model: + raise sse.SmartSimError("Model must be loaded to execute") + + model = load_result.model + results = [model(tensor) for tensor in transform_result.transformed] + + execute_result = mliw.ExecuteResult(results) + return execute_result + + @staticmethod + def transform_output( + request: mliw.InferenceRequest, + execute_result: mliw.ExecuteResult, + ) -> mliw.TransformOutputResult: + # transformed = [item.clone() for item in execute_result.predictions] + # return OutputTransformResult(transformed) + + # transformed = [item.bytes() for item in execute_result.predictions] + + # OutputTransformResult.transformed SHOULD be a list of + # capnproto Tensors Or tensor descriptors accompanying bytes + + # send the original tensors... 
+        execute_result.predictions = [t.detach() for t in execute_result.predictions]
+        # todo: solve sending all tensor metadata that coincides with each prediction
+        return mliw.TransformOutputResult(
+            execute_result.predictions, [1], "c", "float32"
+        )
+        # return OutputTransformResult(transformed)
+
+    # @staticmethod
+    # def serialize_reply(
+    #     request: InferenceRequest, results: OutputTransformResult
+    # ) -> t.Any:
+    #     # results = IntegratedTorchWorker._prepare_outputs(results.outputs)
+    #     # return results
+    #     return None
+    #     # response = MessageHandler.build_response(
+    #     #     status=200,  # todo: are we satisfied with 0/1 (success, fail)
+    #     #     # todo: if not detailed messages, this shouldn't be returned.
+    #     #     message="success",
+    #     #     result=results,
+    #     #     custom_attributes=None,
+    #     # )
+    #     # serialized_resp = MessageHandler.serialize_response(response)
+    #     # return serialized_resp
diff --git a/tests/mli/test_worker_manager.py b/tests/mli/test_worker_manager.py
index 7b345f9ef1..df4b0a637f 100644
--- a/tests/mli/test_worker_manager.py
+++ b/tests/mli/test_worker_manager.py
@@ -149,6 +149,7 @@ def mock_messages(
         model=message_model_key,
         inputs=[message_tensor_input_key],
         outputs=[message_tensor_output_key],
+        output_descriptors=[],
         custom_attributes=None,
     )
     request_bytes = MessageHandler.serialize_request(request)
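
Note for reviewers: the control flow this patch introduces follows one pattern throughout _on_iteration: each pipeline stage runs inside its own try/except, and any failure is routed through exception_handler, which logs the exception and, when a reply channel is available, sends a failure response back before the iteration returns early. The sketch below shows that pattern in isolation under stated assumptions: InMemoryChannel and run_stage are illustrative stand-ins (not part of the SmartSim API), and plain bytes stand in for the serialized capnp failure response.

import logging
import typing as t

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("worker_manager_sketch")


class InMemoryChannel:
    """Hypothetical stand-in for a CommChannelBase reply channel."""

    def __init__(self) -> None:
        self.sent: t.List[bytes] = []

    def send(self, value: bytes) -> None:
        self.sent.append(value)


def exception_handler(
    exc: Exception, reply_channel: t.Optional[InMemoryChannel], failure_message: str
) -> None:
    """Log the failure and, if a reply channel exists, send a failure reply.

    The real worker manager serializes a capnp failure response here; plain
    bytes stand in for it in this sketch."""
    logger.exception(
        f"{failure_message}\n"
        f"Exception type: {type(exc).__name__}\n"
        f"Exception message: {str(exc)}"
    )
    if reply_channel:
        reply_channel.send(f"fail: {failure_message}".encode("utf-8"))


def run_stage(
    stage: t.Callable[[], t.Any],
    reply_channel: InMemoryChannel,
    failure_message: str,
) -> t.Optional[t.Any]:
    """Run one pipeline stage; on failure, reply and return None so the
    caller can abort the iteration early."""
    try:
        return stage()
    except Exception as exc:
        exception_handler(exc, reply_channel, failure_message)
        return None


if __name__ == "__main__":
    channel = InMemoryChannel()
    # A failing stage produces a failure reply instead of crashing the service.
    result = run_stage(lambda: 1 / 0, channel, "Failed while executing.")
    assert result is None
    assert channel.sent == [b"fail: Failed while executing."]

In the actual worker manager the same early return happens inline at each stage of _on_iteration rather than through a helper such as run_stage, which is why every stage in the patch carries its own try/except and failure message.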