From 345c739d22742f5fc4c367284964e35301e3d0d5 Mon Sep 17 00:00:00 2001 From: Alyssa Cote Date: Mon, 1 Jul 2024 15:20:54 -0700 Subject: [PATCH 01/57] started tests --- tests/mli/test_error_handling.py | 128 +++++++++++++++++++++++++++++++ 1 file changed, 128 insertions(+) create mode 100644 tests/mli/test_error_handling.py diff --git a/tests/mli/test_error_handling.py b/tests/mli/test_error_handling.py new file mode 100644 index 0000000000..bd10f1113b --- /dev/null +++ b/tests/mli/test_error_handling.py @@ -0,0 +1,128 @@ +# BSD 2-Clause License +# +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ +import io +import logging +import multiprocessing as mp +import pathlib +import time +import typing as t + +import pytest +import torch + +from smartsim._core.mli.infrastructure.control.workermanager import WorkerManager +from smartsim._core.mli.infrastructure.storage.featurestore import FeatureStore +from smartsim._core.mli.message_handler import MessageHandler +from smartsim.log import get_logger + +from .channel import FileSystemCommChannel +from .featurestore import FileSystemFeatureStore +from .worker import IntegratedTorchWorker + +work_queue: "mp.Queue[bytes]" = mp.Queue() +integrated_worker = IntegratedTorchWorker() +file_system_store = FileSystemFeatureStore() + + +worker_manager = WorkerManager( + work_queue, + integrated_worker, + file_system_store, + as_service=False, + cooldown=10, + comm_channel_type=FileSystemCommChannel, +) +tensor_key = MessageHandler.build_tensor_key("key") +request = MessageHandler.build_request(b"channel", b"model", [tensor_key], [tensor_key], [], None) +ser_request = MessageHandler.serialize_request(request) + +def test_execute_errors_handled(monkeypatch): + + def mock_execute(): + raise ValueError("Simulated error in execute") + + work_queue.put(ser_request) + + monkeypatch.setattr(integrated_worker, "execute", mock_execute) + + worker_manager._on_iteration() + + +def test_fetch_model_errors_handled(monkeypatch): + + def mock_fetch_model(a, b): + raise ValueError("Simulated error in fetch_model") + + work_queue.put(ser_request) + + monkeypatch.setattr(integrated_worker, "fetch_model", mock_fetch_model) + + worker_manager._on_iteration() + + +def test_load_model_errors_handled(monkeypatch): + + def mock_load_model(a, b): + raise ValueError("Simulated error in load_model") + + work_queue.put(ser_request) + + monkeypatch.setattr(integrated_worker, "load_model", mock_load_model) + worker_manager._on_iteration() + + +def test_fetch_inputs_errors_handled(monkeypatch): + + def mock_fetch_inputs(a, b): + raise ValueError("Simulated error in fetch_inputs") + + work_queue.put(ser_request) + + monkeypatch.setattr(integrated_worker, "fetch_inputs", mock_fetch_inputs) + worker_manager._on_iteration() + + +def test_transform_input_errors_handled(monkeypatch): + + def mock_transform_input(a, b): + raise ValueError("Simulated error in transform_input") + + work_queue.put(ser_request) + + monkeypatch.setattr(integrated_worker, "transform_input", mock_transform_input) + worker_manager._on_iteration() + + +def test_transform_output_errors_handled(monkeypatch): + + def mock_transform_output(a, b): + raise ValueError("Simulated error in transform_output") + + work_queue.put(ser_request) + + monkeypatch.setattr(integrated_worker, "transform_output", mock_transform_output) + worker_manager._on_iteration() \ No newline at end of file From 68ed0903ac588314cd3b9ec3a91ed0ce18909e7b Mon Sep 17 00:00:00 2001 From: Alyssa Cote Date: Mon, 1 Jul 2024 15:40:50 -0700 Subject: [PATCH 02/57] more tests --- tests/mli/test_error_handling.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/tests/mli/test_error_handling.py b/tests/mli/test_error_handling.py index bd10f1113b..a1a5a27ba1 100644 --- a/tests/mli/test_error_handling.py +++ b/tests/mli/test_error_handling.py @@ -27,9 +27,6 @@ import io import logging import multiprocessing as mp -import pathlib -import time -import typing as t import pytest import torch @@ -125,4 +122,15 @@ def mock_transform_output(a, b): work_queue.put(ser_request) monkeypatch.setattr(integrated_worker, "transform_output", 
mock_transform_output) + worker_manager._on_iteration() + + +def test_place_output_errors_handled(monkeypatch): + + def mock_place_output(a, b, c): + raise ValueError("Simulated error in place_output") + + work_queue.put(ser_request) + + monkeypatch.setattr(integrated_worker, "place_output", mock_place_output) worker_manager._on_iteration() \ No newline at end of file From 548caa40ee3ffa3b93c33b7009a33f622efdead5 Mon Sep 17 00:00:00 2001 From: Alyssa Cote Date: Mon, 1 Jul 2024 15:54:03 -0700 Subject: [PATCH 03/57] starting to handle errors --- channel | 0 .../mli/infrastructure/control/workermanager.py | 17 ++++++++++++++++- 2 files changed, 16 insertions(+), 1 deletion(-) create mode 100644 channel diff --git a/channel b/channel new file mode 100644 index 0000000000..e69de29bb2 diff --git a/smartsim/_core/mli/infrastructure/control/workermanager.py b/smartsim/_core/mli/infrastructure/control/workermanager.py index b113f9187e..3d5c435568 100644 --- a/smartsim/_core/mli/infrastructure/control/workermanager.py +++ b/smartsim/_core/mli/infrastructure/control/workermanager.py @@ -239,7 +239,22 @@ def _on_iteration(self) -> None: # # let the worker perform additional custom deserialization # request = self._worker.deserialize(request_bytes) - fetch_model_result = self._worker.fetch_model(request, self._feature_store) + try: + fetch_model_result = self._worker.fetch_model(request, self._feature_store) + try: + model_result = self._worker.load_model(request, fetch_model_result) + except Exception as e: + logger.exception( + f"An error occurred while loading the model." + f"Exception type: {type(e).__name__}." + f"Exception message: {str(e)}" + ) + except Exception as e: + logger.exception( + f"An error occurred while fetching the model." + f"Exception type: {type(e).__name__}." + f"Exception message: {str(e)}" + ) model_result = self._worker.load_model(request, fetch_model_result) fetch_input_result = self._worker.fetch_inputs(request, self._feature_store) transformed_input = self._worker.transform_input(request, fetch_input_result) From 57faf7006b09d60dc520735c0af7be8b462e55d8 Mon Sep 17 00:00:00 2001 From: Alyssa Cote Date: Tue, 2 Jul 2024 09:26:12 -0700 Subject: [PATCH 04/57] nested try excepts. not good. 
--- channel | 0 .../infrastructure/control/workermanager.py | 102 ++++++++++++------ tests/mli/test_error_handling.py | 62 +++++------ 3 files changed, 102 insertions(+), 62 deletions(-) delete mode 100644 channel diff --git a/channel b/channel deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/smartsim/_core/mli/infrastructure/control/workermanager.py b/smartsim/_core/mli/infrastructure/control/workermanager.py index 3d5c435568..890cbb3e9c 100644 --- a/smartsim/_core/mli/infrastructure/control/workermanager.py +++ b/smartsim/_core/mli/infrastructure/control/workermanager.py @@ -239,47 +239,85 @@ def _on_iteration(self) -> None: # # let the worker perform additional custom deserialization # request = self._worker.deserialize(request_bytes) + reply = InferenceReply() + try: fetch_model_result = self._worker.fetch_model(request, self._feature_store) try: model_result = self._worker.load_model(request, fetch_model_result) + try: + fetch_input_result = self._worker.fetch_inputs( + request, self._feature_store + ) + try: + transformed_input = self._worker.transform_input( + request, fetch_input_result + ) + try: + execute_result = self._worker.execute( + request, model_result, transformed_input + ) + try: + transformed_output = self._worker.transform_output( + request, execute_result + ) + if request.output_keys: + try: + reply.output_keys = self._worker.place_output( + request, + transformed_output, + self._feature_store, + ) + except Exception as e: + logger.exception( + f"An error occurred while placing the output." + f"Exception type: {type(e).__name__}." + f"Exception message: {str(e)}" + ) + reply.failed = True + else: + reply.outputs = transformed_output.outputs + except Exception as e: + logger.exception( + f"An error occurred while transforming the output." + f"Exception type: {type(e).__name__}." + f"Exception message: {str(e)}" + ) + reply.failed = True + except Exception as e: + logger.exception( + f"An error occurred while executing." + f"Exception type: {type(e).__name__}." + f"Exception message: {str(e)}" + ) + reply.failed = True + except Exception as e: + logger.exception( + f"An error occurred while transforming the input." + f"Exception type: {type(e).__name__}." + f"Exception message: {str(e)}" + ) + reply.failed = True + except Exception as e: + logger.exception( + f"An error occurred while fetching the model." + f"Exception type: {type(e).__name__}." + f"Exception message: {str(e)}" + ) + reply.failed = True except Exception as e: logger.exception( - f"An error occurred while loading the model." - f"Exception type: {type(e).__name__}." - f"Exception message: {str(e)}" - ) + f"An error occurred while loading the model." + f"Exception type: {type(e).__name__}." + f"Exception message: {str(e)}" + ) + reply.failed = True except Exception as e: logger.exception( - f"An error occurred while fetching the model." - f"Exception type: {type(e).__name__}." - f"Exception message: {str(e)}" - ) - model_result = self._worker.load_model(request, fetch_model_result) - fetch_input_result = self._worker.fetch_inputs(request, self._feature_store) - transformed_input = self._worker.transform_input(request, fetch_input_result) - - # batch: t.Collection[_Datum] = transform_result.transformed_input - # if self._batch_size: - # batch = self._worker.batch_requests(transform_result, self._batch_size) - - reply = InferenceReply() - - try: - execute_result = self._worker.execute( - request, model_result, transformed_input + f"An error occurred while fetching the model." 
+ f"Exception type: {type(e).__name__}." + f"Exception message: {str(e)}" ) - - transformed_output = self._worker.transform_output(request, execute_result) - - if request.output_keys: - reply.output_keys = self._worker.place_output( - request, transformed_output, self._feature_store - ) - else: - reply.outputs = transformed_output.outputs - except Exception: - logger.exception("Error executing worker") reply.failed = True if reply.failed: diff --git a/tests/mli/test_error_handling.py b/tests/mli/test_error_handling.py index a1a5a27ba1..556e59c9d2 100644 --- a/tests/mli/test_error_handling.py +++ b/tests/mli/test_error_handling.py @@ -54,83 +54,85 @@ comm_channel_type=FileSystemCommChannel, ) tensor_key = MessageHandler.build_tensor_key("key") -request = MessageHandler.build_request(b"channel", b"model", [tensor_key], [tensor_key], [], None) +request = MessageHandler.build_request( + b"channel", b"model", [tensor_key], [tensor_key], [], None +) ser_request = MessageHandler.serialize_request(request) -def test_execute_errors_handled(monkeypatch): +def test_execute_errors_handled(monkeypatch): def mock_execute(): raise ValueError("Simulated error in execute") - + work_queue.put(ser_request) - + monkeypatch.setattr(integrated_worker, "execute", mock_execute) worker_manager._on_iteration() def test_fetch_model_errors_handled(monkeypatch): - def mock_fetch_model(a, b): - raise ValueError("Simulated error in fetch_model") - + raise ValueError("Simulated error in fetch_model") + work_queue.put(ser_request) - + monkeypatch.setattr(integrated_worker, "fetch_model", mock_fetch_model) worker_manager._on_iteration() def test_load_model_errors_handled(monkeypatch): - def mock_load_model(a, b): - raise ValueError("Simulated error in load_model") - + raise ValueError("Simulated error in load_model") + work_queue.put(ser_request) - + monkeypatch.setattr(integrated_worker, "load_model", mock_load_model) worker_manager._on_iteration() def test_fetch_inputs_errors_handled(monkeypatch): - def mock_fetch_inputs(a, b): - raise ValueError("Simulated error in fetch_inputs") - + raise ValueError("Simulated error in fetch_inputs") + work_queue.put(ser_request) - + monkeypatch.setattr(integrated_worker, "fetch_inputs", mock_fetch_inputs) worker_manager._on_iteration() def test_transform_input_errors_handled(monkeypatch): - def mock_transform_input(a, b): - raise ValueError("Simulated error in transform_input") - + raise ValueError("Simulated error in transform_input") + work_queue.put(ser_request) - + monkeypatch.setattr(integrated_worker, "transform_input", mock_transform_input) worker_manager._on_iteration() def test_transform_output_errors_handled(monkeypatch): - def mock_transform_output(a, b): - raise ValueError("Simulated error in transform_output") - + raise ValueError("Simulated error in transform_output") + work_queue.put(ser_request) - + monkeypatch.setattr(integrated_worker, "transform_output", mock_transform_output) worker_manager._on_iteration() -def test_place_output_errors_handled(monkeypatch): - +def test_place_output_errors_handled(monkeypatch, caplog): def mock_place_output(a, b, c): - raise ValueError("Simulated error in place_output") - + raise ValueError("Simulated error in place_output") + work_queue.put(ser_request) - + monkeypatch.setattr(integrated_worker, "place_output", mock_place_output) - worker_manager._on_iteration() \ No newline at end of file + worker_manager._on_iteration() + + with caplog.at_level(logging.ERROR): + worker_manager._on_iteration() + + # Check if the expected error 
message was logged + assert any("Simulated error in place_output" in message for message in caplog.text) From da3e0a790280c7cf6fe9f24c0d1baad0a25983e0 Mon Sep 17 00:00:00 2001 From: Alyssa Cote Date: Tue, 2 Jul 2024 10:37:44 -0700 Subject: [PATCH 05/57] building up replies --- .../infrastructure/control/workermanager.py | 27 ++++++++++++++----- .../_core/mli/infrastructure/worker/worker.py | 7 +++++ tests/mli/test_error_handling.py | 9 ++----- 3 files changed, 30 insertions(+), 13 deletions(-) diff --git a/smartsim/_core/mli/infrastructure/control/workermanager.py b/smartsim/_core/mli/infrastructure/control/workermanager.py index 890cbb3e9c..7e0214a901 100644 --- a/smartsim/_core/mli/infrastructure/control/workermanager.py +++ b/smartsim/_core/mli/infrastructure/control/workermanager.py @@ -148,8 +148,8 @@ def build_reply(reply: InferenceReply) -> Response: results = prepare_outputs(reply) return MessageHandler.build_response( - status="complete", - message="success", + status=reply.status_enum, + message=reply.message, result=results, custom_attributes=None, ) @@ -275,6 +275,8 @@ def _on_iteration(self) -> None: f"Exception message: {str(e)}" ) reply.failed = True + reply.status_enum = "fail" + reply.message = "Failed while placing the output." else: reply.outputs = transformed_output.outputs except Exception as e: @@ -284,6 +286,8 @@ def _on_iteration(self) -> None: f"Exception message: {str(e)}" ) reply.failed = True + reply.status_enum = "fail" + reply.message = "Failed while transforming the output." except Exception as e: logger.exception( f"An error occurred while executing." @@ -291,6 +295,8 @@ def _on_iteration(self) -> None: f"Exception message: {str(e)}" ) reply.failed = True + reply.status_enum = "fail" + reply.message = "Failed while executing." except Exception as e: logger.exception( f"An error occurred while transforming the input." @@ -298,13 +304,17 @@ def _on_iteration(self) -> None: f"Exception message: {str(e)}" ) reply.failed = True + reply.status_enum = "fail" + reply.message = "Failed while transforming the input." except Exception as e: logger.exception( - f"An error occurred while fetching the model." + f"An error occurred while fetching the inputs." f"Exception type: {type(e).__name__}." f"Exception message: {str(e)}" ) reply.failed = True + reply.status_enum = "fail" + reply.message = "Failed while fetching the inputs." except Exception as e: logger.exception( f"An error occurred while loading the model." @@ -312,6 +322,8 @@ def _on_iteration(self) -> None: f"Exception message: {str(e)}" ) reply.failed = True + reply.status_enum = "fail" + reply.message = "Failed while loading the model." except Exception as e: logger.exception( f"An error occurred while fetching the model." @@ -319,14 +331,17 @@ def _on_iteration(self) -> None: f"Exception message: {str(e)}" ) reply.failed = True + reply.status_enum = "fail" + reply.message = "Failed while fetching the model." 
if reply.failed: - response = build_failure_reply("fail", "failure-occurred") + response = build_failure_reply(reply.status_enum, reply.message) else: if reply.outputs is None or not reply.outputs: - response = build_failure_reply("fail", "no-results") + response = build_failure_reply("fail", "Outputs not found.") - response = build_reply(reply) + else: + response = build_reply(reply) # serialized = self._worker.serialize_reply(request, transformed_output) serialized_resp = MessageHandler.serialize_response(response) # type: ignore diff --git a/smartsim/_core/mli/infrastructure/worker/worker.py b/smartsim/_core/mli/infrastructure/worker/worker.py index c87722b290..55bab9feef 100644 --- a/smartsim/_core/mli/infrastructure/worker/worker.py +++ b/smartsim/_core/mli/infrastructure/worker/worker.py @@ -32,6 +32,9 @@ from smartsim._core.mli.infrastructure.storage.featurestore import FeatureStore from smartsim.log import get_logger +if t.TYPE_CHECKING: + from smartsim._core.mli.mli_schemas.response.response_capnp import StatusEnum + logger = get_logger(__name__) @@ -70,11 +73,15 @@ def __init__( outputs: t.Optional[t.Collection[t.Any]] = None, output_keys: t.Optional[t.Collection[str]] = None, failed: bool = False, + status_enum: 'StatusEnum' = "complete", + message: str = "Success" ) -> None: """Initialize the object""" self.outputs: t.Collection[t.Any] = outputs or [] self.output_keys: t.Collection[t.Optional[str]] = output_keys or [] self.failed = failed + self.status_enum = status_enum + self.message = message class LoadModelResult: diff --git a/tests/mli/test_error_handling.py b/tests/mli/test_error_handling.py index 556e59c9d2..a59fc5f99d 100644 --- a/tests/mli/test_error_handling.py +++ b/tests/mli/test_error_handling.py @@ -122,7 +122,7 @@ def mock_transform_output(a, b): worker_manager._on_iteration() -def test_place_output_errors_handled(monkeypatch, caplog): +def test_place_output_errors_handled(monkeypatch): def mock_place_output(a, b, c): raise ValueError("Simulated error in place_output") @@ -130,9 +130,4 @@ def mock_place_output(a, b, c): monkeypatch.setattr(integrated_worker, "place_output", mock_place_output) worker_manager._on_iteration() - - with caplog.at_level(logging.ERROR): - worker_manager._on_iteration() - - # Check if the expected error message was logged - assert any("Simulated error in place_output" in message for message in caplog.text) + From d574fa16674b3aaa61b699753ba43b9311401a65 Mon Sep 17 00:00:00 2001 From: Alyssa Cote Date: Tue, 2 Jul 2024 13:53:39 -0700 Subject: [PATCH 06/57] replaced my nestedness with if checks, added helper function and tests --- .../infrastructure/control/workermanager.py | 157 ++++++++---------- .../_core/mli/infrastructure/worker/worker.py | 4 +- tests/mli/test_error_handling.py | 87 ++++++---- tests/mli/test_reply_building.py | 66 ++++++++ 4 files changed, 197 insertions(+), 117 deletions(-) create mode 100644 tests/mli/test_reply_building.py diff --git a/smartsim/_core/mli/infrastructure/control/workermanager.py b/smartsim/_core/mli/infrastructure/control/workermanager.py index 7e0214a901..cd28d974f6 100644 --- a/smartsim/_core/mli/infrastructure/control/workermanager.py +++ b/smartsim/_core/mli/infrastructure/control/workermanager.py @@ -155,6 +155,19 @@ def build_reply(reply: InferenceReply) -> Response: ) +def exception_handler( + exc: Exception, func_descriptor: str, reply: InferenceReply +) -> None: + logger.exception( + f"An error occurred while {func_descriptor}." + f"Exception type: {type(exc).__name__}." 
+ f"Exception message: {str(exc)}" + ) + reply.failed = True + reply.status_enum = "fail" + reply.message = f"Failed while {func_descriptor}." + + class WorkerManager(Service): """An implementation of a service managing distribution of tasks to machine learning workers""" @@ -241,98 +254,72 @@ def _on_iteration(self) -> None: reply = InferenceReply() + fetch_model_result = None + model_result = None + fetch_input_result = None + transformed_input = None + execute_result = None + transformed_output = None + try: fetch_model_result = self._worker.fetch_model(request, self._feature_store) + except Exception as e: + exception_handler(e, "fetching the model", reply) + + if not reply.failed and fetch_model_result is not None: try: model_result = self._worker.load_model(request, fetch_model_result) + except Exception as e: + exception_handler(e, "loading the model", reply) + + if not reply.failed: + try: + fetch_input_result = self._worker.fetch_inputs( + request, self._feature_store + ) + except Exception as e: + exception_handler(e, "fetching the inputs", reply) + + if not reply.failed and fetch_input_result is not None: + try: + transformed_input = self._worker.transform_input( + request, fetch_input_result + ) + except Exception as e: + exception_handler(e, "transforming the input", reply) + + if ( + not reply.failed + and model_result is not None + and transformed_input is not None + ): + try: + execute_result = self._worker.execute( + request, model_result, transformed_input + ) + except Exception as e: + exception_handler(e, "executing", reply) + + if not reply.failed and execute_result is not None: + try: + transformed_output = self._worker.transform_output( + request, execute_result + ) + except Exception as e: + exception_handler(e, "transforming the output", reply) + + if not reply.failed and transformed_output is not None: + if request.output_keys: try: - fetch_input_result = self._worker.fetch_inputs( - request, self._feature_store + reply.output_keys = self._worker.place_output( + request, + transformed_output, + self._feature_store, ) - try: - transformed_input = self._worker.transform_input( - request, fetch_input_result - ) - try: - execute_result = self._worker.execute( - request, model_result, transformed_input - ) - try: - transformed_output = self._worker.transform_output( - request, execute_result - ) - if request.output_keys: - try: - reply.output_keys = self._worker.place_output( - request, - transformed_output, - self._feature_store, - ) - except Exception as e: - logger.exception( - f"An error occurred while placing the output." - f"Exception type: {type(e).__name__}." - f"Exception message: {str(e)}" - ) - reply.failed = True - reply.status_enum = "fail" - reply.message = "Failed while placing the output." - else: - reply.outputs = transformed_output.outputs - except Exception as e: - logger.exception( - f"An error occurred while transforming the output." - f"Exception type: {type(e).__name__}." - f"Exception message: {str(e)}" - ) - reply.failed = True - reply.status_enum = "fail" - reply.message = "Failed while transforming the output." - except Exception as e: - logger.exception( - f"An error occurred while executing." - f"Exception type: {type(e).__name__}." - f"Exception message: {str(e)}" - ) - reply.failed = True - reply.status_enum = "fail" - reply.message = "Failed while executing." - except Exception as e: - logger.exception( - f"An error occurred while transforming the input." - f"Exception type: {type(e).__name__}." 
- f"Exception message: {str(e)}" - ) - reply.failed = True - reply.status_enum = "fail" - reply.message = "Failed while transforming the input." except Exception as e: - logger.exception( - f"An error occurred while fetching the inputs." - f"Exception type: {type(e).__name__}." - f"Exception message: {str(e)}" - ) - reply.failed = True - reply.status_enum = "fail" - reply.message = "Failed while fetching the inputs." - except Exception as e: - logger.exception( - f"An error occurred while loading the model." - f"Exception type: {type(e).__name__}." - f"Exception message: {str(e)}" - ) - reply.failed = True - reply.status_enum = "fail" - reply.message = "Failed while loading the model." - except Exception as e: - logger.exception( - f"An error occurred while fetching the model." - f"Exception type: {type(e).__name__}." - f"Exception message: {str(e)}" - ) - reply.failed = True - reply.status_enum = "fail" - reply.message = "Failed while fetching the model." + exception_handler(e, "placing the output", reply) + else: + reply.outputs = transformed_output.outputs if reply.failed: response = build_failure_reply(reply.status_enum, reply.message) diff --git a/smartsim/_core/mli/infrastructure/worker/worker.py b/smartsim/_core/mli/infrastructure/worker/worker.py index 55bab9feef..11522d298f 100644 --- a/smartsim/_core/mli/infrastructure/worker/worker.py +++ b/smartsim/_core/mli/infrastructure/worker/worker.py @@ -73,8 +73,8 @@ def __init__( outputs: t.Optional[t.Collection[t.Any]] = None, output_keys: t.Optional[t.Collection[str]] = None, failed: bool = False, - status_enum: 'StatusEnum' = "complete", - message: str = "Success" + status_enum: "StatusEnum" = "complete", + message: str = "Success", ) -> None: """Initialize the object""" self.outputs: t.Collection[t.Any] = outputs or [] diff --git a/tests/mli/test_error_handling.py b/tests/mli/test_error_handling.py index a59fc5f99d..46d803989a 100644 --- a/tests/mli/test_error_handling.py +++ b/tests/mli/test_error_handling.py @@ -24,43 +24,48 @@ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
-import io -import logging import multiprocessing as mp import pytest -import torch -from smartsim._core.mli.infrastructure.control.workermanager import WorkerManager -from smartsim._core.mli.infrastructure.storage.featurestore import FeatureStore +from smartsim._core.mli.infrastructure.control.workermanager import ( + WorkerManager, + exception_handler, +) +from smartsim._core.mli.infrastructure.worker.worker import InferenceReply from smartsim._core.mli.message_handler import MessageHandler -from smartsim.log import get_logger from .channel import FileSystemCommChannel from .featurestore import FileSystemFeatureStore from .worker import IntegratedTorchWorker -work_queue: "mp.Queue[bytes]" = mp.Queue() -integrated_worker = IntegratedTorchWorker() -file_system_store = FileSystemFeatureStore() +@pytest.fixture +def setup_worker_manager(test_dir): + work_queue: "mp.Queue[bytes]" = mp.Queue() + integrated_worker = IntegratedTorchWorker() + file_system_store = FileSystemFeatureStore(test_dir) -worker_manager = WorkerManager( - work_queue, - integrated_worker, - file_system_store, - as_service=False, - cooldown=10, - comm_channel_type=FileSystemCommChannel, -) -tensor_key = MessageHandler.build_tensor_key("key") -request = MessageHandler.build_request( - b"channel", b"model", [tensor_key], [tensor_key], [], None -) -ser_request = MessageHandler.serialize_request(request) + worker_manager = WorkerManager( + work_queue, + integrated_worker, + file_system_store, + as_service=False, + cooldown=10, + comm_channel_type=FileSystemCommChannel, + ) + tensor_key = MessageHandler.build_tensor_key("key") + request = MessageHandler.build_request( + b"channel", b"model", [tensor_key], [tensor_key], [], None + ) + ser_request = MessageHandler.serialize_request(request) + + return worker_manager, work_queue, integrated_worker, ser_request -def test_execute_errors_handled(monkeypatch): +def test_execute_errors_handled(setup_worker_manager, monkeypatch): + worker_manager, work_queue, integrated_worker, ser_request = setup_worker_manager + def mock_execute(): raise ValueError("Simulated error in execute") @@ -71,7 +76,9 @@ def mock_execute(): worker_manager._on_iteration() -def test_fetch_model_errors_handled(monkeypatch): +def test_fetch_model_errors_handled(setup_worker_manager, monkeypatch): + worker_manager, work_queue, integrated_worker, ser_request = setup_worker_manager + def mock_fetch_model(a, b): raise ValueError("Simulated error in fetch_model") @@ -82,7 +89,9 @@ def mock_fetch_model(a, b): worker_manager._on_iteration() -def test_load_model_errors_handled(monkeypatch): +def test_load_model_errors_handled(setup_worker_manager, monkeypatch): + worker_manager, work_queue, integrated_worker, ser_request = setup_worker_manager + def mock_load_model(a, b): raise ValueError("Simulated error in load_model") @@ -92,7 +101,9 @@ def mock_load_model(a, b): worker_manager._on_iteration() -def test_fetch_inputs_errors_handled(monkeypatch): +def test_fetch_inputs_errors_handled(setup_worker_manager, monkeypatch): + worker_manager, work_queue, integrated_worker, ser_request = setup_worker_manager + def mock_fetch_inputs(a, b): raise ValueError("Simulated error in fetch_inputs") @@ -102,7 +113,9 @@ def mock_fetch_inputs(a, b): worker_manager._on_iteration() -def test_transform_input_errors_handled(monkeypatch): +def test_transform_input_errors_handled(setup_worker_manager, monkeypatch): + worker_manager, work_queue, integrated_worker, ser_request = setup_worker_manager + def mock_transform_input(a, b): raise 
ValueError("Simulated error in transform_input") @@ -112,7 +125,9 @@ def mock_transform_input(a, b): worker_manager._on_iteration() -def test_transform_output_errors_handled(monkeypatch): +def test_transform_output_errors_handled(setup_worker_manager, monkeypatch): + worker_manager, work_queue, integrated_worker, ser_request = setup_worker_manager + def mock_transform_output(a, b): raise ValueError("Simulated error in transform_output") @@ -122,7 +137,9 @@ def mock_transform_output(a, b): worker_manager._on_iteration() -def test_place_output_errors_handled(monkeypatch): +def test_place_output_errors_handled(setup_worker_manager, monkeypatch): + worker_manager, work_queue, integrated_worker, ser_request = setup_worker_manager + def mock_place_output(a, b, c): raise ValueError("Simulated error in place_output") @@ -130,4 +147,14 @@ def mock_place_output(a, b, c): monkeypatch.setattr(integrated_worker, "place_output", mock_place_output) worker_manager._on_iteration() - + + +def test_exception_handling_helper(): + reply = InferenceReply() + + test_exception = ValueError("Test ValueError") + exception_handler(test_exception, "fetching the model", reply) + + assert reply.failed == True + assert reply.status_enum == "fail" + assert reply.message == "Failed while fetching the model." diff --git a/tests/mli/test_reply_building.py b/tests/mli/test_reply_building.py new file mode 100644 index 0000000000..20f0de38e1 --- /dev/null +++ b/tests/mli/test_reply_building.py @@ -0,0 +1,66 @@ +# BSD 2-Clause License +# +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ +import pytest + +from smartsim._core.mli.infrastructure.control.workermanager import ( + build_failure_reply, + build_reply, +) +from smartsim._core.mli.infrastructure.worker.worker import InferenceReply + + +@pytest.mark.parametrize( + "status, message", + [ + pytest.param("timeout", "Worker timed out", id="timeout"), + pytest.param("fail", "Failed while executing", id="fail"), + ], +) +def test_build_failure_reply(status, message): + response = build_failure_reply(status, message) + assert response.status == status + assert response.message == message + + +def test_build_failure_reply_fails(): + with pytest.raises(ValueError): + response = build_failure_reply("not a status enum", "message") + + +@pytest.mark.parametrize( + "status, message", + [ + pytest.param("complete", "Success", id="complete"), + ], +) +def test_build_reply(status, message): + reply = InferenceReply() + reply.status_enum = status + reply.message = message + response = build_reply(reply) + assert response.status == status + assert response.message == message From 4fa1184de2157259fb19bdf5f84965ce63a3c7a0 Mon Sep 17 00:00:00 2001 From: Alyssa Cote Date: Tue, 2 Jul 2024 15:47:02 -0700 Subject: [PATCH 07/57] doc string and newlines --- .../mli/infrastructure/control/workermanager.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/smartsim/_core/mli/infrastructure/control/workermanager.py b/smartsim/_core/mli/infrastructure/control/workermanager.py index cd28d974f6..707eec0eea 100644 --- a/smartsim/_core/mli/infrastructure/control/workermanager.py +++ b/smartsim/_core/mli/infrastructure/control/workermanager.py @@ -158,9 +158,17 @@ def build_reply(reply: InferenceReply) -> Response: def exception_handler( exc: Exception, func_descriptor: str, reply: InferenceReply ) -> None: + """ + Logs exceptions and sets reply attributes without taking + down the WorkerManager. + + :param exc: The exception to be logged + :param func_descriptor: Descriptor to help form error messages + :param reply: InferenceReply to modify + """ logger.exception( - f"An error occurred while {func_descriptor}." - f"Exception type: {type(exc).__name__}." 
+ f"An error occurred while {func_descriptor}.\n" + f"Exception type: {type(exc).__name__}.\n" f"Exception message: {str(exc)}" ) reply.failed = True From 5fa1a37d2354d2eef938e81539f8297e288898cb Mon Sep 17 00:00:00 2001 From: Alyssa Cote Date: Tue, 2 Jul 2024 15:52:49 -0700 Subject: [PATCH 08/57] changelog --- doc/changelog.md | 1 + 1 file changed, 1 insertion(+) diff --git a/doc/changelog.md b/doc/changelog.md index 9e6fb33e17..c963c92e21 100644 --- a/doc/changelog.md +++ b/doc/changelog.md @@ -13,6 +13,7 @@ Jump to: Description +- Add error handling in Worker Manager pipeline - Add ML worker manager, sample worker, and feature store - Added schemas and MessageHandler class for de/serialization of inference requests and response messages From 8709710695501d1731564361607438f3b32ca23f Mon Sep 17 00:00:00 2001 From: Alyssa Cote Date: Wed, 3 Jul 2024 09:44:44 -0700 Subject: [PATCH 09/57] test groups --- tests/mli/test_error_handling.py | 3 +++ tests/mli/test_reply_building.py | 2 ++ 2 files changed, 5 insertions(+) diff --git a/tests/mli/test_error_handling.py b/tests/mli/test_error_handling.py index 46d803989a..3e04c47baf 100644 --- a/tests/mli/test_error_handling.py +++ b/tests/mli/test_error_handling.py @@ -39,6 +39,9 @@ from .featurestore import FileSystemFeatureStore from .worker import IntegratedTorchWorker +# The tests in this file belong to the group_a group +pytestmark = pytest.mark.group_a + @pytest.fixture def setup_worker_manager(test_dir): diff --git a/tests/mli/test_reply_building.py b/tests/mli/test_reply_building.py index 20f0de38e1..c38bc69655 100644 --- a/tests/mli/test_reply_building.py +++ b/tests/mli/test_reply_building.py @@ -32,6 +32,8 @@ ) from smartsim._core.mli.infrastructure.worker.worker import InferenceReply +# The tests in this file belong to the group_b group +pytestmark = pytest.mark.group_b @pytest.mark.parametrize( "status, message", From 96bf89d8af400c18c10e2110d2daa90545823f3f Mon Sep 17 00:00:00 2001 From: Alyssa Cote Date: Wed, 3 Jul 2024 09:45:10 -0700 Subject: [PATCH 10/57] style --- tests/mli/test_reply_building.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/mli/test_reply_building.py b/tests/mli/test_reply_building.py index c38bc69655..133caa89ae 100644 --- a/tests/mli/test_reply_building.py +++ b/tests/mli/test_reply_building.py @@ -35,6 +35,7 @@ # The tests in this file belong to the group_b group pytestmark = pytest.mark.group_b + @pytest.mark.parametrize( "status, message", [ From 9d8d51bb2baabb197e9074f291913f5f01e81c3b Mon Sep 17 00:00:00 2001 From: Alyssa Cote Date: Wed, 3 Jul 2024 09:49:19 -0700 Subject: [PATCH 11/57] build reply fails test --- tests/mli/test_reply_building.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/tests/mli/test_reply_building.py b/tests/mli/test_reply_building.py index 133caa89ae..93d1ae3224 100644 --- a/tests/mli/test_reply_building.py +++ b/tests/mli/test_reply_building.py @@ -67,3 +67,10 @@ def test_build_reply(status, message): response = build_reply(reply) assert response.status == status assert response.message == message + + +def test_build_reply_fails(): + with pytest.raises(ValueError): + reply = InferenceReply() + reply.status_enum = "not a status enum" + response = build_reply(reply) \ No newline at end of file From 9cff781bffe7922b4cfc09da52088e4f01a49669 Mon Sep 17 00:00:00 2001 From: Alyssa Cote Date: Wed, 3 Jul 2024 09:57:48 -0700 Subject: [PATCH 12/57] style --- tests/mli/test_reply_building.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git 
a/tests/mli/test_reply_building.py b/tests/mli/test_reply_building.py index 93d1ae3224..354d740eb2 100644 --- a/tests/mli/test_reply_building.py +++ b/tests/mli/test_reply_building.py @@ -73,4 +73,4 @@ def test_build_reply_fails(): with pytest.raises(ValueError): reply = InferenceReply() reply.status_enum = "not a status enum" - response = build_reply(reply) \ No newline at end of file + response = build_reply(reply) From 1e93392d68895eb90e9c9197fa266a3309da12ea Mon Sep 17 00:00:00 2001 From: Alyssa Cote Date: Tue, 9 Jul 2024 16:21:41 -0700 Subject: [PATCH 13/57] add running enum, remove, failed parameter --- .../mli/infrastructure/control/workermanager.py | 17 +++++++++-------- .../_core/mli/infrastructure/worker/worker.py | 6 ++---- .../mli/mli_schemas/response/response.capnp | 1 + .../mli/mli_schemas/response/response_capnp.pyi | 2 +- tests/mli/test_error_handling.py | 3 +-- 5 files changed, 14 insertions(+), 15 deletions(-) diff --git a/smartsim/_core/mli/infrastructure/control/workermanager.py b/smartsim/_core/mli/infrastructure/control/workermanager.py index b8396463a8..8cade35783 100644 --- a/smartsim/_core/mli/infrastructure/control/workermanager.py +++ b/smartsim/_core/mli/infrastructure/control/workermanager.py @@ -172,7 +172,6 @@ def exception_handler( f"Exception type: {type(exc).__name__}.\n" f"Exception message: {str(exc)}" ) - reply.failed = True reply.status_enum = "fail" reply.message = f"Failed while {func_descriptor}." @@ -275,13 +274,13 @@ def _on_iteration(self) -> None: except Exception as e: exception_handler(e, "fetching the model", reply) - if not reply.failed and fetch_model_result is not None: + if reply.status_enum == "running" and fetch_model_result is not None: try: model_result = self._worker.load_model(request, fetch_model_result) except Exception as e: exception_handler(e, "loading the model", reply) - if not reply.failed: + if reply.status_enum == "running": try: fetch_input_result = self._worker.fetch_inputs( request, self._feature_store @@ -289,7 +288,7 @@ def _on_iteration(self) -> None: except Exception as e: exception_handler(e, "fetching the inputs", reply) - if not reply.failed and fetch_input_result is not None: + if reply.status_enum == "running" and fetch_input_result is not None: try: transformed_input = self._worker.transform_input( request, fetch_input_result @@ -298,7 +297,7 @@ def _on_iteration(self) -> None: exception_handler(e, "transforming the input", reply) if ( - not reply.failed + reply.status_enum == "running" and model_result is not None and transformed_input is not None ): @@ -309,7 +308,7 @@ def _on_iteration(self) -> None: except Exception as e: exception_handler(e, "executing", reply) - if not reply.failed and execute_result is not None: + if reply.status_enum == "running" and execute_result is not None: try: transformed_output = self._worker.transform_output( request, execute_result @@ -317,7 +316,7 @@ def _on_iteration(self) -> None: except Exception as e: exception_handler(e, "transforming the output", reply) - if not reply.failed and transformed_output is not None: + if reply.status_enum == "running" and transformed_output is not None: if request.output_keys: try: reply.output_keys = self._worker.place_output( @@ -330,13 +329,15 @@ def _on_iteration(self) -> None: else: reply.outputs = transformed_output.outputs - if reply.failed: + if reply.status_enum != "running": response = build_failure_reply(reply.status_enum, reply.message) else: if reply.outputs is None or not reply.outputs: response = build_failure_reply("fail", 
"Outputs not found.") else: + reply.status_enum = "complete" + reply.message = "Success" response = build_reply(reply) # serialized = self._worker.serialize_reply(request, transformed_output) diff --git a/smartsim/_core/mli/infrastructure/worker/worker.py b/smartsim/_core/mli/infrastructure/worker/worker.py index da90260470..77781f2286 100644 --- a/smartsim/_core/mli/infrastructure/worker/worker.py +++ b/smartsim/_core/mli/infrastructure/worker/worker.py @@ -73,14 +73,12 @@ def __init__( self, outputs: t.Optional[t.Collection[t.Any]] = None, output_keys: t.Optional[t.Collection[str]] = None, - failed: bool = False, - status_enum: "StatusEnum" = "complete", - message: str = "Success", + status_enum: "StatusEnum" = "running", + message: str = "In progress", ) -> None: """Initialize the object""" self.outputs: t.Collection[t.Any] = outputs or [] self.output_keys: t.Collection[t.Optional[str]] = output_keys or [] - self.failed = failed self.status_enum = status_enum self.message = message diff --git a/smartsim/_core/mli/mli_schemas/response/response.capnp b/smartsim/_core/mli/mli_schemas/response/response.capnp index 67375b5a97..61df9d7104 100644 --- a/smartsim/_core/mli/mli_schemas/response/response.capnp +++ b/smartsim/_core/mli/mli_schemas/response/response.capnp @@ -34,6 +34,7 @@ enum StatusEnum { complete @0; fail @1; timeout @2; + running @3; } struct Response { diff --git a/smartsim/_core/mli/mli_schemas/response/response_capnp.pyi b/smartsim/_core/mli/mli_schemas/response/response_capnp.pyi index f6d7f8444e..e20f3b79ee 100644 --- a/smartsim/_core/mli/mli_schemas/response/response_capnp.pyi +++ b/smartsim/_core/mli/mli_schemas/response/response_capnp.pyi @@ -45,7 +45,7 @@ from .response_attributes.response_attributes_capnp import ( TorchResponseAttributesReader, ) -StatusEnum = Literal["complete", "fail", "timeout"] +StatusEnum = Literal["complete", "fail", "timeout", "running"] class Response: class Result: diff --git a/tests/mli/test_error_handling.py b/tests/mli/test_error_handling.py index bc61b8da2f..a470e6071e 100644 --- a/tests/mli/test_error_handling.py +++ b/tests/mli/test_error_handling.py @@ -54,7 +54,7 @@ def setup_worker_manager(test_dir): integrated_worker, file_system_store, as_service=False, - cooldown=10, + cooldown=3, comm_channel_type=FileSystemCommChannel, ) tensor_key = MessageHandler.build_tensor_key("key") @@ -159,6 +159,5 @@ def test_exception_handling_helper(): test_exception = ValueError("Test ValueError") exception_handler(test_exception, "fetching the model", reply) - assert reply.failed == True assert reply.status_enum == "fail" assert reply.message == "Failed while fetching the model." 
From ada632d8cc6c86ddeac5a7e6c224abb1e2dd5ae2 Mon Sep 17 00:00:00 2001 From: Alyssa Cote Date: Wed, 10 Jul 2024 09:36:53 -0700 Subject: [PATCH 14/57] import or skip dragon --- tests/mli/test_error_handling.py | 6 ++++-- tests/mli/test_reply_building.py | 6 ++++-- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/tests/mli/test_error_handling.py b/tests/mli/test_error_handling.py index a470e6071e..6d7b089c0d 100644 --- a/tests/mli/test_error_handling.py +++ b/tests/mli/test_error_handling.py @@ -28,6 +28,8 @@ import pytest +dragon = pytest.importorskip("dragon") + from smartsim._core.mli.infrastructure.control.workermanager import ( WorkerManager, exception_handler, @@ -39,8 +41,8 @@ from .featurestore import FileSystemFeatureStore from .worker import IntegratedTorchWorker -# The tests in this file belong to the group_a group -pytestmark = pytest.mark.group_a +# The tests in this file belong to the dragon group +pytestmark = pytest.mark.dragon @pytest.fixture diff --git a/tests/mli/test_reply_building.py b/tests/mli/test_reply_building.py index 354d740eb2..4370503af1 100644 --- a/tests/mli/test_reply_building.py +++ b/tests/mli/test_reply_building.py @@ -26,14 +26,16 @@ import pytest +dragon = pytest.importorskip("dragon") + from smartsim._core.mli.infrastructure.control.workermanager import ( build_failure_reply, build_reply, ) from smartsim._core.mli.infrastructure.worker.worker import InferenceReply -# The tests in this file belong to the group_b group -pytestmark = pytest.mark.group_b +# The tests in this file belong to the dragon group +pytestmark = pytest.mark.dragon @pytest.mark.parametrize( From 6bb16f5a094b318e3970bb5384118346837de2ca Mon Sep 17 00:00:00 2001 From: Alyssa Cote Date: Wed, 10 Jul 2024 10:00:51 -0700 Subject: [PATCH 15/57] fix workermanager init in tests --- tests/mli/test_error_handling.py | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/tests/mli/test_error_handling.py b/tests/mli/test_error_handling.py index 6d7b089c0d..7e1a772b94 100644 --- a/tests/mli/test_error_handling.py +++ b/tests/mli/test_error_handling.py @@ -30,12 +30,18 @@ dragon = pytest.importorskip("dragon") +import dragon.utils as du +from dragon.channels import Channel +from dragon.data.ddict.ddict import DDict +from dragon.fli import DragonFLIError, FLInterface + from smartsim._core.mli.infrastructure.control.workermanager import ( WorkerManager, exception_handler, ) from smartsim._core.mli.infrastructure.worker.worker import InferenceReply from smartsim._core.mli.message_handler import MessageHandler +from smartsim._core.mli.infrastructure.environmentloader import EnvironmentConfigLoader from .channel import FileSystemCommChannel from .featurestore import FileSystemFeatureStore @@ -46,15 +52,16 @@ @pytest.fixture -def setup_worker_manager(test_dir): - work_queue: "mp.Queue[bytes]" = mp.Queue() +def setup_worker_manager(test_dir, monkeypatch): integrated_worker = IntegratedTorchWorker() - file_system_store = FileSystemFeatureStore(test_dir) + + chan = Channel.make_process_local() + queue = FLInterface(main_ch=chan) + monkeypatch.setenv("SSQueue", du.B64.bytes_to_str(queue.serialize())) worker_manager = WorkerManager( - work_queue, + EnvironmentConfigLoader(), integrated_worker, - file_system_store, as_service=False, cooldown=3, comm_channel_type=FileSystemCommChannel, @@ -66,7 +73,7 @@ def setup_worker_manager(test_dir): ) ser_request = MessageHandler.serialize_request(request) - return worker_manager, work_queue, integrated_worker, ser_request + 
return worker_manager, worker_manager._task_queue, integrated_worker, ser_request def test_execute_errors_handled(setup_worker_manager, monkeypatch): From 6fa2144a88e12334cc1b5699bbe26da6cbab1be4 Mon Sep 17 00:00:00 2001 From: Alyssa Cote Date: Wed, 10 Jul 2024 10:10:02 -0700 Subject: [PATCH 16/57] style --- tests/mli/test_error_handling.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/mli/test_error_handling.py b/tests/mli/test_error_handling.py index 7e1a772b94..3cb7685bcc 100644 --- a/tests/mli/test_error_handling.py +++ b/tests/mli/test_error_handling.py @@ -39,9 +39,9 @@ WorkerManager, exception_handler, ) +from smartsim._core.mli.infrastructure.environmentloader import EnvironmentConfigLoader from smartsim._core.mli.infrastructure.worker.worker import InferenceReply from smartsim._core.mli.message_handler import MessageHandler -from smartsim._core.mli.infrastructure.environmentloader import EnvironmentConfigLoader from .channel import FileSystemCommChannel from .featurestore import FileSystemFeatureStore From 9ed4c6bfc2b911753953b3e2c5761da1a2847ef6 Mon Sep 17 00:00:00 2001 From: Alyssa Cote Date: Wed, 10 Jul 2024 11:31:26 -0700 Subject: [PATCH 17/57] fixing tests --- tests/mli/test_error_handling.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/tests/mli/test_error_handling.py b/tests/mli/test_error_handling.py index 3cb7685bcc..4b2b705984 100644 --- a/tests/mli/test_error_handling.py +++ b/tests/mli/test_error_handling.py @@ -66,24 +66,26 @@ def setup_worker_manager(test_dir, monkeypatch): cooldown=3, comm_channel_type=FileSystemCommChannel, ) + + tensor_key = MessageHandler.build_tensor_key("key") model = MessageHandler.build_model(b"model", "model name", "v 0.0.1") request = MessageHandler.build_request( b"channel", model, [tensor_key], [tensor_key], [], None ) ser_request = MessageHandler.serialize_request(request) + new_sender = worker_manager._task_queue.sendh(use_main_as_stream_channel=True) + new_sender.send_bytes(ser_request) - return worker_manager, worker_manager._task_queue, integrated_worker, ser_request + return worker_manager, integrated_worker def test_execute_errors_handled(setup_worker_manager, monkeypatch): - worker_manager, work_queue, integrated_worker, ser_request = setup_worker_manager + worker_manager, integrated_worker = setup_worker_manager def mock_execute(): raise ValueError("Simulated error in execute") - work_queue.put(ser_request) - monkeypatch.setattr(integrated_worker, "execute", mock_execute) worker_manager._on_iteration() From ee7b9d13efa2e6135ea90d2d4ab31d9488acf2a3 Mon Sep 17 00:00:00 2001 From: Alyssa Cote Date: Wed, 10 Jul 2024 11:47:03 -0700 Subject: [PATCH 18/57] style --- tests/mli/test_error_handling.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/mli/test_error_handling.py b/tests/mli/test_error_handling.py index 4b2b705984..7340a81a85 100644 --- a/tests/mli/test_error_handling.py +++ b/tests/mli/test_error_handling.py @@ -67,7 +67,6 @@ def setup_worker_manager(test_dir, monkeypatch): comm_channel_type=FileSystemCommChannel, ) - tensor_key = MessageHandler.build_tensor_key("key") model = MessageHandler.build_model(b"model", "model name", "v 0.0.1") request = MessageHandler.build_request( From 7e41d9dcac27efb34872dc3024185eb0bcf82f9b Mon Sep 17 00:00:00 2001 From: Alyssa Cote Date: Wed, 10 Jul 2024 12:20:36 -0700 Subject: [PATCH 19/57] fix _on_iteration --- .../infrastructure/control/workermanager.py | 3 ++- tests/mli/test_error_handling.py | 23 +++++-------------- 2 files 
changed, 8 insertions(+), 18 deletions(-) diff --git a/smartsim/_core/mli/infrastructure/control/workermanager.py b/smartsim/_core/mli/infrastructure/control/workermanager.py index 6347d94bbe..c6334a0d3a 100644 --- a/smartsim/_core/mli/infrastructure/control/workermanager.py +++ b/smartsim/_core/mli/infrastructure/control/workermanager.py @@ -255,7 +255,8 @@ def _on_iteration(self) -> None: return # perform default deserialization of the message envelope - request_bytes: bytes = self._task_queue.get() + receiver = self._task_queue.recvh(use_main_as_stream_channel=True) + request_bytes, _ = receiver.recv_bytes() request = deserialize_message(request_bytes, self._comm_channel_type) if not self._validate_request(request): diff --git a/tests/mli/test_error_handling.py b/tests/mli/test_error_handling.py index 7340a81a85..5cbb52e0ca 100644 --- a/tests/mli/test_error_handling.py +++ b/tests/mli/test_error_handling.py @@ -91,73 +91,62 @@ def mock_execute(): def test_fetch_model_errors_handled(setup_worker_manager, monkeypatch): - worker_manager, work_queue, integrated_worker, ser_request = setup_worker_manager + worker_manager, integrated_worker = setup_worker_manager def mock_fetch_model(a, b): raise ValueError("Simulated error in fetch_model") - work_queue.put(ser_request) - monkeypatch.setattr(integrated_worker, "fetch_model", mock_fetch_model) worker_manager._on_iteration() def test_load_model_errors_handled(setup_worker_manager, monkeypatch): - worker_manager, work_queue, integrated_worker, ser_request = setup_worker_manager + worker_manager, integrated_worker = setup_worker_manager def mock_load_model(a, b): raise ValueError("Simulated error in load_model") - work_queue.put(ser_request) - monkeypatch.setattr(integrated_worker, "load_model", mock_load_model) worker_manager._on_iteration() def test_fetch_inputs_errors_handled(setup_worker_manager, monkeypatch): - worker_manager, work_queue, integrated_worker, ser_request = setup_worker_manager + worker_manager, integrated_worker = setup_worker_manager def mock_fetch_inputs(a, b): raise ValueError("Simulated error in fetch_inputs") - work_queue.put(ser_request) - monkeypatch.setattr(integrated_worker, "fetch_inputs", mock_fetch_inputs) worker_manager._on_iteration() def test_transform_input_errors_handled(setup_worker_manager, monkeypatch): - worker_manager, work_queue, integrated_worker, ser_request = setup_worker_manager + worker_manager, integrated_worker = setup_worker_manager def mock_transform_input(a, b): raise ValueError("Simulated error in transform_input") - work_queue.put(ser_request) - monkeypatch.setattr(integrated_worker, "transform_input", mock_transform_input) worker_manager._on_iteration() def test_transform_output_errors_handled(setup_worker_manager, monkeypatch): - worker_manager, work_queue, integrated_worker, ser_request = setup_worker_manager + worker_manager, integrated_worker = setup_worker_manager def mock_transform_output(a, b): raise ValueError("Simulated error in transform_output") - work_queue.put(ser_request) - monkeypatch.setattr(integrated_worker, "transform_output", mock_transform_output) worker_manager._on_iteration() def test_place_output_errors_handled(setup_worker_manager, monkeypatch): - worker_manager, work_queue, integrated_worker, ser_request = setup_worker_manager + worker_manager, integrated_worker = setup_worker_manager def mock_place_output(a, b, c): raise ValueError("Simulated error in place_output") - work_queue.put(ser_request) monkeypatch.setattr(integrated_worker, "place_output", 
mock_place_output) worker_manager._on_iteration() From d2b7977890a9a244851568fbad3a2e2bbdca5e85 Mon Sep 17 00:00:00 2001 From: Alyssa Cote Date: Wed, 10 Jul 2024 12:25:48 -0700 Subject: [PATCH 20/57] style --- tests/mli/test_error_handling.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/mli/test_error_handling.py b/tests/mli/test_error_handling.py index 5cbb52e0ca..945fc164d0 100644 --- a/tests/mli/test_error_handling.py +++ b/tests/mli/test_error_handling.py @@ -147,7 +147,6 @@ def test_place_output_errors_handled(setup_worker_manager, monkeypatch): def mock_place_output(a, b, c): raise ValueError("Simulated error in place_output") - monkeypatch.setattr(integrated_worker, "place_output", mock_place_output) worker_manager._on_iteration() From 8b7924b7cfd9c330bd6e09a1a81c64d9c4f70ea1 Mon Sep 17 00:00:00 2001 From: Alyssa Cote Date: Wed, 10 Jul 2024 13:51:34 -0700 Subject: [PATCH 21/57] return failure in reply channel asap --- .../infrastructure/control/workermanager.py | 124 ++++++++---------- 1 file changed, 55 insertions(+), 69 deletions(-) diff --git a/smartsim/_core/mli/infrastructure/control/workermanager.py b/smartsim/_core/mli/infrastructure/control/workermanager.py index c6334a0d3a..00914a6c4f 100644 --- a/smartsim/_core/mli/infrastructure/control/workermanager.py +++ b/smartsim/_core/mli/infrastructure/control/workermanager.py @@ -160,7 +160,10 @@ def build_reply(reply: InferenceReply) -> Response: def exception_handler( - exc: Exception, func_descriptor: str, reply: InferenceReply + exc: Exception, + reply_channel: t.Optional[CommChannelBase], + func_descriptor: str, + reply: InferenceReply, ) -> None: """ Logs exceptions and sets reply attributes without taking @@ -177,6 +180,10 @@ def exception_handler( ) reply.status_enum = "fail" reply.message = f"Failed while {func_descriptor}." 
+ response = build_failure_reply(reply.status_enum, reply.message) + serialized_resp = MessageHandler.serialize_response(response) # type: ignore + if reply_channel: + reply_channel.send(serialized_resp) class WorkerManager(Service): @@ -262,90 +269,69 @@ def _on_iteration(self) -> None: if not self._validate_request(request): return - # # let the worker perform additional custom deserialization - # request = self._worker.deserialize(request_bytes) - reply = InferenceReply() - fetch_model_result = None - model_result = None - fetch_input_result = None - transformed_input = None - execute_result = None - transformed_output = None - try: fetch_model_result = self._worker.fetch_model(request, self._feature_store) except Exception as e: - exception_handler(e, "fetching the model", reply) + exception_handler(e, request.callback, "fetching the model", reply) + return - if reply.status_enum == "running" and fetch_model_result is not None: - try: - model_result = self._worker.load_model(request, fetch_model_result) - except Exception as e: - exception_handler(e, "loading the model", reply) + try: + model_result = self._worker.load_model(request, fetch_model_result) + except Exception as e: + exception_handler(e, request.callback, "loading the model", reply) + return - if reply.status_enum == "running": - try: - fetch_input_result = self._worker.fetch_inputs( - request, self._feature_store - ) - except Exception as e: - exception_handler(e, "fetching the inputs", reply) + try: + fetch_input_result = self._worker.fetch_inputs(request, self._feature_store) + except Exception as e: + exception_handler(e, request.callback, "fetching the inputs", reply) + return - if reply.status_enum == "running" and fetch_input_result is not None: - try: - transformed_input = self._worker.transform_input( - request, fetch_input_result - ) - except Exception as e: - exception_handler(e, "transforming the input", reply) + try: + transformed_input = self._worker.transform_input( + request, fetch_input_result + ) + except Exception as e: + exception_handler(e, request.callback, "transforming the input", reply) + return - if ( - reply.status_enum == "running" - and model_result is not None - and transformed_input is not None - ): - try: - execute_result = self._worker.execute( - request, model_result, transformed_input - ) - except Exception as e: - exception_handler(e, "executing", reply) + try: + execute_result = self._worker.execute( + request, model_result, transformed_input + ) + except Exception as e: + exception_handler(e, request.callback, "executing", reply) + return + + try: + transformed_output = self._worker.transform_output(request, execute_result) + except Exception as e: + exception_handler(e, request.callback, "transforming the output", reply) + return - if reply.status_enum == "running" and execute_result is not None: + if request.output_keys: try: - transformed_output = self._worker.transform_output( - request, execute_result + reply.output_keys = self._worker.place_output( + request, + transformed_output, + self._feature_store, ) except Exception as e: - exception_handler(e, "transforming the output", reply) - - if reply.status_enum == "running" and transformed_output is not None: - if request.output_keys: - try: - reply.output_keys = self._worker.place_output( - request, - transformed_output, - self._feature_store, - ) - except Exception as e: - exception_handler(e, "placing the output", reply) - else: - reply.outputs = transformed_output.outputs - - if reply.status_enum != "running": - response = 
build_failure_reply(reply.status_enum, reply.message) + exception_handler(e, request.callback, "placing the output", reply) + return else: - if reply.outputs is None or not reply.outputs: - response = build_failure_reply("fail", "Outputs not found.") + reply.outputs = transformed_output.outputs - else: - reply.status_enum = "complete" - reply.message = "Success" - response = build_reply(reply) + if reply.outputs is None or not reply.outputs: + response = build_failure_reply("fail", "Outputs not found.") + + else: + reply.status_enum = "complete" + reply.message = "Success" + response = build_reply(reply) - # serialized = self._worker.serialize_reply(request, transformed_output) serialized_resp = MessageHandler.serialize_response(response) # type: ignore if request.callback: request.callback.send(serialized_resp) From ac2784a75fa3d5d574ae7c2a517c1743c691d1cc Mon Sep 17 00:00:00 2001 From: Alyssa Cote Date: Wed, 10 Jul 2024 14:01:43 -0700 Subject: [PATCH 22/57] fix test --- tests/mli/test_error_handling.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/mli/test_error_handling.py b/tests/mli/test_error_handling.py index 945fc164d0..ecf4461d88 100644 --- a/tests/mli/test_error_handling.py +++ b/tests/mli/test_error_handling.py @@ -155,7 +155,7 @@ def test_exception_handling_helper(): reply = InferenceReply() test_exception = ValueError("Test ValueError") - exception_handler(test_exception, "fetching the model", reply) + exception_handler(test_exception, None, "fetching the model", reply) assert reply.status_enum == "fail" assert reply.message == "Failed while fetching the model." From 2ab2d50f91c730e1fef3297219a5e4a0da7553eb Mon Sep 17 00:00:00 2001 From: Alyssa Cote Date: Wed, 10 Jul 2024 14:38:24 -0700 Subject: [PATCH 23/57] pr comments --- .../mli/infrastructure/control/workermanager.py | 5 +++-- tests/mli/test_error_handling.py | 16 ++++++++++++++++ tests/mli/test_reply_building.py | 12 ++++++++++-- 3 files changed, 29 insertions(+), 4 deletions(-) diff --git a/smartsim/_core/mli/infrastructure/control/workermanager.py b/smartsim/_core/mli/infrastructure/control/workermanager.py index 00914a6c4f..76a2ad1821 100644 --- a/smartsim/_core/mli/infrastructure/control/workermanager.py +++ b/smartsim/_core/mli/infrastructure/control/workermanager.py @@ -166,10 +166,11 @@ def exception_handler( reply: InferenceReply, ) -> None: """ - Logs exceptions and sets reply attributes without taking - down the WorkerManager. + Logs exceptions, sets reply attributes, and sends the + failure response without taking down the WorkerManager. 
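The refactor above replaces the chained `if reply.status_enum == "running"` guards with one try/except per pipeline stage: the first failure is logged, reported on the reply channel, and the iteration returns immediately so the service loop stays alive. A condensed sketch of that control flow, using simplified stand-ins rather than the actual SmartSim classes or message schema:

    # Illustrative stand-ins only -- not the real WorkerManager, reply schema, or channel API.
    class Reply:
        def __init__(self) -> None:
            self.status = "running"
            self.message = ""

    def handle_request(worker, request, feature_store, reply_channel) -> None:
        reply = Reply()

        def fail(descriptor: str) -> None:
            # mirrors exception_handler: mark the reply failed and answer right away
            reply.status, reply.message = "fail", f"Failed while {descriptor}."
            if reply_channel:
                reply_channel.send(f"{reply.status}: {reply.message}".encode())

        try:
            fetched = worker.fetch_model(request, feature_store)
        except Exception:
            fail("fetching the model")
            return  # early return; the service loop moves on to the next request
        try:
            model_result = worker.load_model(request, fetched)
        except Exception:
            fail("loading the model")
            return
        # fetch_inputs, transform_input, execute, transform_output, and place_output
        # follow the same try/except-and-return pattern before a success reply is sent.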
:param exc: The exception to be logged + :param reply_channel: The channel used to send replies :param func_descriptor: Descriptor to help form error messages :param reply: InferenceReply to modify """ diff --git a/tests/mli/test_error_handling.py b/tests/mli/test_error_handling.py index ecf4461d88..4b00b48bbd 100644 --- a/tests/mli/test_error_handling.py +++ b/tests/mli/test_error_handling.py @@ -80,6 +80,8 @@ def setup_worker_manager(test_dir, monkeypatch): def test_execute_errors_handled(setup_worker_manager, monkeypatch): + """Ensures that the worker manager does not crash aftera failure in the + execute pipeline stage""" worker_manager, integrated_worker = setup_worker_manager def mock_execute(): @@ -91,6 +93,8 @@ def mock_execute(): def test_fetch_model_errors_handled(setup_worker_manager, monkeypatch): + """Ensures that the worker manager does not crash aftera failure in the + fetch model pipeline stage""" worker_manager, integrated_worker = setup_worker_manager def mock_fetch_model(a, b): @@ -102,6 +106,8 @@ def mock_fetch_model(a, b): def test_load_model_errors_handled(setup_worker_manager, monkeypatch): + """Ensures that the worker manager does not crash aftera failure in the + load model pipeline stage""" worker_manager, integrated_worker = setup_worker_manager def mock_load_model(a, b): @@ -112,6 +118,8 @@ def mock_load_model(a, b): def test_fetch_inputs_errors_handled(setup_worker_manager, monkeypatch): + """Ensures that the worker manager does not crash aftera failure in the + fetch inputs pipeline stage""" worker_manager, integrated_worker = setup_worker_manager def mock_fetch_inputs(a, b): @@ -122,6 +130,8 @@ def mock_fetch_inputs(a, b): def test_transform_input_errors_handled(setup_worker_manager, monkeypatch): + """Ensures that the worker manager does not crash aftera failure in the + transform input pipeline stage""" worker_manager, integrated_worker = setup_worker_manager def mock_transform_input(a, b): @@ -132,6 +142,8 @@ def mock_transform_input(a, b): def test_transform_output_errors_handled(setup_worker_manager, monkeypatch): + """Ensures that the worker manager does not crash aftera failure in the + transform output pipeline stage""" worker_manager, integrated_worker = setup_worker_manager def mock_transform_output(a, b): @@ -142,6 +154,8 @@ def mock_transform_output(a, b): def test_place_output_errors_handled(setup_worker_manager, monkeypatch): + """Ensures that the worker manager does not crash aftera failure in the + place output pipeline stage""" worker_manager, integrated_worker = setup_worker_manager def mock_place_output(a, b, c): @@ -152,6 +166,8 @@ def mock_place_output(a, b, c): def test_exception_handling_helper(): + """Ensures that the worker manager does not crash aftera failure in the + execute pipeline stage""" reply = InferenceReply() test_exception = ValueError("Test ValueError") diff --git a/tests/mli/test_reply_building.py b/tests/mli/test_reply_building.py index 4370503af1..d1fb565eee 100644 --- a/tests/mli/test_reply_building.py +++ b/tests/mli/test_reply_building.py @@ -46,15 +46,19 @@ ], ) def test_build_failure_reply(status, message): + "Ensures failure replies can be built successfully" response = build_failure_reply(status, message) assert response.status == status assert response.message == message def test_build_failure_reply_fails(): - with pytest.raises(ValueError): + "Ensures ValueError is raised if a StatusEnum is not used" + with pytest.raises(ValueError) as ex: response = build_failure_reply("not a status enum", "message") + assert 
"Error assigning status to response" in ex.value.args[0] + @pytest.mark.parametrize( "status, message", @@ -63,6 +67,7 @@ def test_build_failure_reply_fails(): ], ) def test_build_reply(status, message): + "Ensures replies can be built successfully" reply = InferenceReply() reply.status_enum = status reply.message = message @@ -72,7 +77,10 @@ def test_build_reply(status, message): def test_build_reply_fails(): - with pytest.raises(ValueError): + "Ensures ValueError is raised if a StatusEnum is not used" + with pytest.raises(ValueError) as ex: reply = InferenceReply() reply.status_enum = "not a status enum" response = build_reply(reply) + + assert "Error assigning status to response" in ex.value.args[0] \ No newline at end of file From a399fd515dfbac27883e0b74c2d794f9cdeb37d1 Mon Sep 17 00:00:00 2001 From: Alyssa Cote Date: Wed, 10 Jul 2024 14:46:55 -0700 Subject: [PATCH 24/57] more pr comments --- .../infrastructure/control/workermanager.py | 2 +- tests/mli/test_error_handling.py | 26 ++++++++++++++----- tests/mli/test_reply_building.py | 11 +++++--- 3 files changed, 28 insertions(+), 11 deletions(-) diff --git a/smartsim/_core/mli/infrastructure/control/workermanager.py b/smartsim/_core/mli/infrastructure/control/workermanager.py index 76a2ad1821..50a73eb8c6 100644 --- a/smartsim/_core/mli/infrastructure/control/workermanager.py +++ b/smartsim/_core/mli/infrastructure/control/workermanager.py @@ -113,7 +113,7 @@ def deserialize_message( return inference_request -def build_failure_reply(status: "StatusEnum", message: str) -> Response: +def build_failure_reply(status: StatusEnum, message: str) -> Response: return MessageHandler.build_response( status=status, # todo: need to indicate correct status message=message, # todo: decide what these will be diff --git a/tests/mli/test_error_handling.py b/tests/mli/test_error_handling.py index 4b00b48bbd..39616d5a47 100644 --- a/tests/mli/test_error_handling.py +++ b/tests/mli/test_error_handling.py @@ -79,7 +79,7 @@ def setup_worker_manager(test_dir, monkeypatch): return worker_manager, integrated_worker -def test_execute_errors_handled(setup_worker_manager, monkeypatch): +def test_execute_errors_handled(setup_worker_manager, monkeypatch: pytest.MonkeyPatch): """Ensures that the worker manager does not crash aftera failure in the execute pipeline stage""" worker_manager, integrated_worker = setup_worker_manager @@ -92,7 +92,9 @@ def mock_execute(): worker_manager._on_iteration() -def test_fetch_model_errors_handled(setup_worker_manager, monkeypatch): +def test_fetch_model_errors_handled( + setup_worker_manager, monkeypatch: pytest.MonkeyPatch +): """Ensures that the worker manager does not crash aftera failure in the fetch model pipeline stage""" worker_manager, integrated_worker = setup_worker_manager @@ -105,7 +107,9 @@ def mock_fetch_model(a, b): worker_manager._on_iteration() -def test_load_model_errors_handled(setup_worker_manager, monkeypatch): +def test_load_model_errors_handled( + setup_worker_manager, monkeypatch: pytest.MonkeyPatch +): """Ensures that the worker manager does not crash aftera failure in the load model pipeline stage""" worker_manager, integrated_worker = setup_worker_manager @@ -117,7 +121,9 @@ def mock_load_model(a, b): worker_manager._on_iteration() -def test_fetch_inputs_errors_handled(setup_worker_manager, monkeypatch): +def test_fetch_inputs_errors_handled( + setup_worker_manager, monkeypatch: pytest.MonkeyPatch +): """Ensures that the worker manager does not crash aftera failure in the fetch inputs pipeline stage""" 
worker_manager, integrated_worker = setup_worker_manager @@ -129,7 +135,9 @@ def mock_fetch_inputs(a, b): worker_manager._on_iteration() -def test_transform_input_errors_handled(setup_worker_manager, monkeypatch): +def test_transform_input_errors_handled( + setup_worker_manager, monkeypatch: pytest.MonkeyPatch +): """Ensures that the worker manager does not crash aftera failure in the transform input pipeline stage""" worker_manager, integrated_worker = setup_worker_manager @@ -141,7 +149,9 @@ def mock_transform_input(a, b): worker_manager._on_iteration() -def test_transform_output_errors_handled(setup_worker_manager, monkeypatch): +def test_transform_output_errors_handled( + setup_worker_manager, monkeypatch: pytest.MonkeyPatch +): """Ensures that the worker manager does not crash aftera failure in the transform output pipeline stage""" worker_manager, integrated_worker = setup_worker_manager @@ -153,7 +163,9 @@ def mock_transform_output(a, b): worker_manager._on_iteration() -def test_place_output_errors_handled(setup_worker_manager, monkeypatch): +def test_place_output_errors_handled( + setup_worker_manager, monkeypatch: pytest.MonkeyPatch +): """Ensures that the worker manager does not crash aftera failure in the place output pipeline stage""" worker_manager, integrated_worker = setup_worker_manager diff --git a/tests/mli/test_reply_building.py b/tests/mli/test_reply_building.py index d1fb565eee..d06a716e5d 100644 --- a/tests/mli/test_reply_building.py +++ b/tests/mli/test_reply_building.py @@ -24,6 +24,8 @@ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +import typing as t + import pytest dragon = pytest.importorskip("dragon") @@ -34,6 +36,9 @@ ) from smartsim._core.mli.infrastructure.worker.worker import InferenceReply +if t.TYPE_CHECKING: + from smartsim._core.mli.mli_schemas.response.response_capnp import StatusEnum + # The tests in this file belong to the dragon group pytestmark = pytest.mark.dragon @@ -45,7 +50,7 @@ pytest.param("fail", "Failed while executing", id="fail"), ], ) -def test_build_failure_reply(status, message): +def test_build_failure_reply(status: StatusEnum, message: str): "Ensures failure replies can be built successfully" response = build_failure_reply(status, message) assert response.status == status @@ -66,7 +71,7 @@ def test_build_failure_reply_fails(): pytest.param("complete", "Success", id="complete"), ], ) -def test_build_reply(status, message): +def test_build_reply(status: StatusEnum, message: str): "Ensures replies can be built successfully" reply = InferenceReply() reply.status_enum = status @@ -83,4 +88,4 @@ def test_build_reply_fails(): reply.status_enum = "not a status enum" response = build_reply(reply) - assert "Error assigning status to response" in ex.value.args[0] \ No newline at end of file + assert "Error assigning status to response" in ex.value.args[0] From 17baec6eba4d1551ca7ccd6909c2a1fee0086e7e Mon Sep 17 00:00:00 2001 From: Alyssa Cote Date: Wed, 10 Jul 2024 14:55:31 -0700 Subject: [PATCH 25/57] typing --- smartsim/_core/mli/infrastructure/control/workermanager.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/smartsim/_core/mli/infrastructure/control/workermanager.py b/smartsim/_core/mli/infrastructure/control/workermanager.py index 50a73eb8c6..76a2ad1821 100644 --- a/smartsim/_core/mli/infrastructure/control/workermanager.py +++ b/smartsim/_core/mli/infrastructure/control/workermanager.py @@ -113,7 +113,7 @@ def 
deserialize_message( return inference_request -def build_failure_reply(status: StatusEnum, message: str) -> Response: +def build_failure_reply(status: "StatusEnum", message: str) -> Response: return MessageHandler.build_response( status=status, # todo: need to indicate correct status message=message, # todo: decide what these will be From f71fd4b41f7c84cbc8e2ce164de6b65a887054b0 Mon Sep 17 00:00:00 2001 From: Alyssa Cote Date: Wed, 10 Jul 2024 15:02:56 -0700 Subject: [PATCH 26/57] more typing --- tests/mli/test_reply_building.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/mli/test_reply_building.py b/tests/mli/test_reply_building.py index d06a716e5d..f585096788 100644 --- a/tests/mli/test_reply_building.py +++ b/tests/mli/test_reply_building.py @@ -50,7 +50,7 @@ pytest.param("fail", "Failed while executing", id="fail"), ], ) -def test_build_failure_reply(status: StatusEnum, message: str): +def test_build_failure_reply(status: "StatusEnum", message: str): "Ensures failure replies can be built successfully" response = build_failure_reply(status, message) assert response.status == status @@ -71,7 +71,7 @@ def test_build_failure_reply_fails(): pytest.param("complete", "Success", id="complete"), ], ) -def test_build_reply(status: StatusEnum, message: str): +def test_build_reply(status: "StatusEnum", message: str): "Ensures replies can be built successfully" reply = InferenceReply() reply.status_enum = status From 9aab212d1b5fa7ba5c799bace31d4bb2b57b058d Mon Sep 17 00:00:00 2001 From: Alyssa Cote Date: Wed, 10 Jul 2024 15:21:27 -0700 Subject: [PATCH 27/57] spelling --- tests/mli/test_error_handling.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/tests/mli/test_error_handling.py b/tests/mli/test_error_handling.py index 39616d5a47..2e8c3dd6a4 100644 --- a/tests/mli/test_error_handling.py +++ b/tests/mli/test_error_handling.py @@ -80,7 +80,7 @@ def setup_worker_manager(test_dir, monkeypatch): def test_execute_errors_handled(setup_worker_manager, monkeypatch: pytest.MonkeyPatch): - """Ensures that the worker manager does not crash aftera failure in the + """Ensures that the worker manager does not crash after a failure in the execute pipeline stage""" worker_manager, integrated_worker = setup_worker_manager @@ -95,7 +95,7 @@ def mock_execute(): def test_fetch_model_errors_handled( setup_worker_manager, monkeypatch: pytest.MonkeyPatch ): - """Ensures that the worker manager does not crash aftera failure in the + """Ensures that the worker manager does not crash after a failure in the fetch model pipeline stage""" worker_manager, integrated_worker = setup_worker_manager @@ -110,7 +110,7 @@ def mock_fetch_model(a, b): def test_load_model_errors_handled( setup_worker_manager, monkeypatch: pytest.MonkeyPatch ): - """Ensures that the worker manager does not crash aftera failure in the + """Ensures that the worker manager does not crash after a failure in the load model pipeline stage""" worker_manager, integrated_worker = setup_worker_manager @@ -124,7 +124,7 @@ def mock_load_model(a, b): def test_fetch_inputs_errors_handled( setup_worker_manager, monkeypatch: pytest.MonkeyPatch ): - """Ensures that the worker manager does not crash aftera failure in the + """Ensures that the worker manager does not crash after a failure in the fetch inputs pipeline stage""" worker_manager, integrated_worker = setup_worker_manager @@ -138,7 +138,7 @@ def mock_fetch_inputs(a, b): def test_transform_input_errors_handled( setup_worker_manager, monkeypatch: 
pytest.MonkeyPatch ): - """Ensures that the worker manager does not crash aftera failure in the + """Ensures that the worker manager does not crash after a failure in the transform input pipeline stage""" worker_manager, integrated_worker = setup_worker_manager @@ -152,7 +152,7 @@ def mock_transform_input(a, b): def test_transform_output_errors_handled( setup_worker_manager, monkeypatch: pytest.MonkeyPatch ): - """Ensures that the worker manager does not crash aftera failure in the + """Ensures that the worker manager does not crash after a failure in the transform output pipeline stage""" worker_manager, integrated_worker = setup_worker_manager @@ -166,7 +166,7 @@ def mock_transform_output(a, b): def test_place_output_errors_handled( setup_worker_manager, monkeypatch: pytest.MonkeyPatch ): - """Ensures that the worker manager does not crash aftera failure in the + """Ensures that the worker manager does not crash after a failure in the place output pipeline stage""" worker_manager, integrated_worker = setup_worker_manager @@ -178,7 +178,7 @@ def mock_place_output(a, b, c): def test_exception_handling_helper(): - """Ensures that the worker manager does not crash aftera failure in the + """Ensures that the worker manager does not crash after a failure in the execute pipeline stage""" reply = InferenceReply() From 7998843ddf99a1d0027a15abeed7a5522651acb5 Mon Sep 17 00:00:00 2001 From: Alyssa Cote Date: Wed, 10 Jul 2024 15:34:56 -0700 Subject: [PATCH 28/57] mock test --- tests/mli/test_error_handling.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/tests/mli/test_error_handling.py b/tests/mli/test_error_handling.py index 2e8c3dd6a4..97b2804b66 100644 --- a/tests/mli/test_error_handling.py +++ b/tests/mli/test_error_handling.py @@ -24,7 +24,7 @@ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
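The typing commits above settle on importing StatusEnum only for type checkers and quoting it in annotations. A small sketch of why that combination works at runtime (the import path below is illustrative, not the real module):

    import typing as t

    if t.TYPE_CHECKING:
        # visible to mypy/IDEs only; absent at runtime
        from response_schemas import StatusEnum  # hypothetical import path

    def describe(status: "StatusEnum") -> str:
        # the quoted annotation is never evaluated at runtime, so no NameError occurs
        return f"status={status}"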
-import multiprocessing as mp +from unittest.mock import MagicMock import pytest @@ -91,6 +91,16 @@ def mock_execute(): worker_manager._on_iteration() + mock_reply_fn = MagicMock() + monkeypatch.setattr( + "smartsim._core.mli.infrastructure.workermanager", + "build_failure_reply", + mock_reply_fn, + ) + + assert mock_reply_fn.called_once() + mock_reply_fn.assert_called_with("fail", "Failed while executing.") + def test_fetch_model_errors_handled( setup_worker_manager, monkeypatch: pytest.MonkeyPatch From a8a6e292a5e423cccf47eacfe917592d25055787 Mon Sep 17 00:00:00 2001 From: Alyssa Cote Date: Wed, 10 Jul 2024 15:45:05 -0700 Subject: [PATCH 29/57] fix test --- tests/mli/test_error_handling.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/mli/test_error_handling.py b/tests/mli/test_error_handling.py index 97b2804b66..082df15e0a 100644 --- a/tests/mli/test_error_handling.py +++ b/tests/mli/test_error_handling.py @@ -93,7 +93,7 @@ def mock_execute(): mock_reply_fn = MagicMock() monkeypatch.setattr( - "smartsim._core.mli.infrastructure.workermanager", + worker_manager, "build_failure_reply", mock_reply_fn, ) From f7399ef362d65959e26783320b13c9b1562e3209 Mon Sep 17 00:00:00 2001 From: Alyssa Cote Date: Wed, 10 Jul 2024 15:58:56 -0700 Subject: [PATCH 30/57] fix test again --- tests/mli/test_error_handling.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/tests/mli/test_error_handling.py b/tests/mli/test_error_handling.py index 082df15e0a..c4ff4c392d 100644 --- a/tests/mli/test_error_handling.py +++ b/tests/mli/test_error_handling.py @@ -32,11 +32,11 @@ import dragon.utils as du from dragon.channels import Channel -from dragon.data.ddict.ddict import DDict -from dragon.fli import DragonFLIError, FLInterface +from dragon.fli import FLInterface from smartsim._core.mli.infrastructure.control.workermanager import ( WorkerManager, + build_failure_reply, exception_handler, ) from smartsim._core.mli.infrastructure.environmentloader import EnvironmentConfigLoader @@ -93,8 +93,7 @@ def mock_execute(): mock_reply_fn = MagicMock() monkeypatch.setattr( - worker_manager, - "build_failure_reply", + build_failure_reply, mock_reply_fn, ) From a8dae656d51584f54875a2b6aa6d4446a500a00b Mon Sep 17 00:00:00 2001 From: Alyssa Cote Date: Wed, 10 Jul 2024 16:11:32 -0700 Subject: [PATCH 31/57] try again --- tests/mli/test_error_handling.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/mli/test_error_handling.py b/tests/mli/test_error_handling.py index c4ff4c392d..3a08e84fd1 100644 --- a/tests/mli/test_error_handling.py +++ b/tests/mli/test_error_handling.py @@ -36,7 +36,6 @@ from smartsim._core.mli.infrastructure.control.workermanager import ( WorkerManager, - build_failure_reply, exception_handler, ) from smartsim._core.mli.infrastructure.environmentloader import EnvironmentConfigLoader @@ -93,7 +92,7 @@ def mock_execute(): mock_reply_fn = MagicMock() monkeypatch.setattr( - build_failure_reply, + "smartsim._core.mli.infrastructure.control.workermanager.build_failure_reply", mock_reply_fn, ) From 6e49886be37fe393929d534f3ef813e1996290cd Mon Sep 17 00:00:00 2001 From: Alyssa Cote Date: Wed, 10 Jul 2024 16:31:36 -0700 Subject: [PATCH 32/57] oops add feature store --- tests/mli/test_error_handling.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/tests/mli/test_error_handling.py b/tests/mli/test_error_handling.py index 3a08e84fd1..5c0fc0f0ee 100644 --- a/tests/mli/test_error_handling.py +++ b/tests/mli/test_error_handling.py @@ 
-24,6 +24,8 @@ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +import base64 +import pickle from unittest.mock import MagicMock import pytest @@ -32,6 +34,7 @@ import dragon.utils as du from dragon.channels import Channel +from dragon.data.ddict.ddict import DDict from dragon.fli import FLInterface from smartsim._core.mli.infrastructure.control.workermanager import ( @@ -39,6 +42,9 @@ exception_handler, ) from smartsim._core.mli.infrastructure.environmentloader import EnvironmentConfigLoader +from smartsim._core.mli.infrastructure.storage.dragonfeaturestore import ( + DragonFeatureStore, +) from smartsim._core.mli.infrastructure.worker.worker import InferenceReply from smartsim._core.mli.message_handler import MessageHandler @@ -57,6 +63,11 @@ def setup_worker_manager(test_dir, monkeypatch): chan = Channel.make_process_local() queue = FLInterface(main_ch=chan) monkeypatch.setenv("SSQueue", du.B64.bytes_to_str(queue.serialize())) + storage = DDict() + feature_store = DragonFeatureStore(storage) + monkeypatch.setenv( + "SSFeatureStore", base64.b64encode(pickle.dumps(feature_store)).decode("utf-8") + ) worker_manager = WorkerManager( EnvironmentConfigLoader(), From 2b1f5d409a3452c383a81d13c5d74d446a30c3a9 Mon Sep 17 00:00:00 2001 From: Alyssa Cote Date: Wed, 10 Jul 2024 16:46:49 -0700 Subject: [PATCH 33/57] add mock return values for previous functions --- tests/mli/test_error_handling.py | 51 ++++++++++++++++++-------------- 1 file changed, 28 insertions(+), 23 deletions(-) diff --git a/tests/mli/test_error_handling.py b/tests/mli/test_error_handling.py index 5c0fc0f0ee..0969f126b4 100644 --- a/tests/mli/test_error_handling.py +++ b/tests/mli/test_error_handling.py @@ -45,7 +45,7 @@ from smartsim._core.mli.infrastructure.storage.dragonfeaturestore import ( DragonFeatureStore, ) -from smartsim._core.mli.infrastructure.worker.worker import InferenceReply +from smartsim._core.mli.infrastructure.worker.worker import InferenceReply, FetchInputResult, FetchModelResult, LoadModelResult, TransformInputResult, ExecuteResult, TransformOutputResult from smartsim._core.mli.message_handler import MessageHandler from .channel import FileSystemCommChannel @@ -89,28 +89,6 @@ def setup_worker_manager(test_dir, monkeypatch): return worker_manager, integrated_worker -def test_execute_errors_handled(setup_worker_manager, monkeypatch: pytest.MonkeyPatch): - """Ensures that the worker manager does not crash after a failure in the - execute pipeline stage""" - worker_manager, integrated_worker = setup_worker_manager - - def mock_execute(): - raise ValueError("Simulated error in execute") - - monkeypatch.setattr(integrated_worker, "execute", mock_execute) - - worker_manager._on_iteration() - - mock_reply_fn = MagicMock() - monkeypatch.setattr( - "smartsim._core.mli.infrastructure.control.workermanager.build_failure_reply", - mock_reply_fn, - ) - - assert mock_reply_fn.called_once() - mock_reply_fn.assert_called_with("fail", "Failed while executing.") - - def test_fetch_model_errors_handled( setup_worker_manager, monkeypatch: pytest.MonkeyPatch ): @@ -168,6 +146,33 @@ def mock_transform_input(a, b): worker_manager._on_iteration() +def test_execute_errors_handled(setup_worker_manager, monkeypatch: pytest.MonkeyPatch): + """Ensures that the worker manager does not crash after a failure in the + execute pipeline stage""" + worker_manager, integrated_worker = setup_worker_manager + + 
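The fixture change above hands the feature store to the worker manager through an environment variable. A self-contained sketch of that encode/decode round trip; a plain dict stands in for DragonFeatureStore, and the decoding half is assumed to mirror what the config loader does:

    import base64
    import os
    import pickle

    store = {"model-key": b"model-bytes"}  # stand-in for the DragonFeatureStore

    # encoding side, as in the fixture: pickle, base64-encode, export as a string
    os.environ["SSFeatureStore"] = base64.b64encode(pickle.dumps(store)).decode("utf-8")

    # assumed decoding side: read the variable, base64-decode, unpickle
    restored = pickle.loads(base64.b64decode(os.environ["SSFeatureStore"]))
    assert restored == store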
monkeypatch.setattr(integrated_worker, "fetch_model", MagicMock(return_value=FetchModelResult(b"result_bytes"))) + monkeypatch.setattr(integrated_worker, "load_model", MagicMock(return_value=LoadModelResult(b"result_bytes"))) + monkeypatch.setattr(integrated_worker, "fetch_inputs", MagicMock(return_value=FetchInputResult(b"result_bytes"))) + monkeypatch.setattr(integrated_worker, "transform_input", MagicMock(return_value=TransformInputResult(b"result_bytes"))) + + def mock_execute(): + raise ValueError("Simulated error in execute") + + monkeypatch.setattr(integrated_worker, "execute", mock_execute) + + worker_manager._on_iteration() + + mock_reply_fn = MagicMock() + monkeypatch.setattr( + "smartsim._core.mli.infrastructure.control.workermanager.build_failure_reply", + mock_reply_fn, + ) + + assert mock_reply_fn.called_once() + mock_reply_fn.assert_called_with("fail", "Failed while executing.") + + def test_transform_output_errors_handled( setup_worker_manager, monkeypatch: pytest.MonkeyPatch ): From b1e385b49b0e2a752fa21e94be89d552bfa3691f Mon Sep 17 00:00:00 2001 From: Alyssa Cote Date: Wed, 10 Jul 2024 16:55:31 -0700 Subject: [PATCH 34/57] style --- tests/mli/test_error_handling.py | 34 +++++++++++++++++++++++++++----- 1 file changed, 29 insertions(+), 5 deletions(-) diff --git a/tests/mli/test_error_handling.py b/tests/mli/test_error_handling.py index 0969f126b4..e2de9a45d4 100644 --- a/tests/mli/test_error_handling.py +++ b/tests/mli/test_error_handling.py @@ -45,7 +45,15 @@ from smartsim._core.mli.infrastructure.storage.dragonfeaturestore import ( DragonFeatureStore, ) -from smartsim._core.mli.infrastructure.worker.worker import InferenceReply, FetchInputResult, FetchModelResult, LoadModelResult, TransformInputResult, ExecuteResult, TransformOutputResult +from smartsim._core.mli.infrastructure.worker.worker import ( + ExecuteResult, + FetchInputResult, + FetchModelResult, + InferenceReply, + LoadModelResult, + TransformInputResult, + TransformOutputResult, +) from smartsim._core.mli.message_handler import MessageHandler from .channel import FileSystemCommChannel @@ -151,10 +159,26 @@ def test_execute_errors_handled(setup_worker_manager, monkeypatch: pytest.Monkey execute pipeline stage""" worker_manager, integrated_worker = setup_worker_manager - monkeypatch.setattr(integrated_worker, "fetch_model", MagicMock(return_value=FetchModelResult(b"result_bytes"))) - monkeypatch.setattr(integrated_worker, "load_model", MagicMock(return_value=LoadModelResult(b"result_bytes"))) - monkeypatch.setattr(integrated_worker, "fetch_inputs", MagicMock(return_value=FetchInputResult(b"result_bytes"))) - monkeypatch.setattr(integrated_worker, "transform_input", MagicMock(return_value=TransformInputResult(b"result_bytes"))) + monkeypatch.setattr( + integrated_worker, + "fetch_model", + MagicMock(return_value=FetchModelResult(b"result_bytes")), + ) + monkeypatch.setattr( + integrated_worker, + "load_model", + MagicMock(return_value=LoadModelResult(b"result_bytes")), + ) + monkeypatch.setattr( + integrated_worker, + "fetch_inputs", + MagicMock(return_value=FetchInputResult([b"result_bytes"])), + ) + monkeypatch.setattr( + integrated_worker, + "transform_input", + MagicMock(return_value=TransformInputResult(b"result_bytes")), + ) def mock_execute(): raise ValueError("Simulated error in execute") From e64c0e62eb6db23e05a30a4d0ec16e27be4dd31f Mon Sep 17 00:00:00 2001 From: Alyssa Cote Date: Wed, 10 Jul 2024 17:06:29 -0700 Subject: [PATCH 35/57] positional args --- tests/mli/test_error_handling.py | 2 +- 1 file 
changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/mli/test_error_handling.py b/tests/mli/test_error_handling.py index e2de9a45d4..2bd71d7b22 100644 --- a/tests/mli/test_error_handling.py +++ b/tests/mli/test_error_handling.py @@ -180,7 +180,7 @@ def test_execute_errors_handled(setup_worker_manager, monkeypatch: pytest.Monkey MagicMock(return_value=TransformInputResult(b"result_bytes")), ) - def mock_execute(): + def mock_execute(a, b, c): raise ValueError("Simulated error in execute") monkeypatch.setattr(integrated_worker, "execute", mock_execute) From 148da63f168c3920f97d74a01a234091d54cad2f Mon Sep 17 00:00:00 2001 From: Alyssa Cote Date: Thu, 11 Jul 2024 12:52:47 -0500 Subject: [PATCH 36/57] mock tests --- .../infrastructure/control/workermanager.py | 7 +- tests/mli/test_error_handling.py | 197 ++++++++++++++++-- 2 files changed, 187 insertions(+), 17 deletions(-) diff --git a/smartsim/_core/mli/infrastructure/control/workermanager.py b/smartsim/_core/mli/infrastructure/control/workermanager.py index 76a2ad1821..fc87fb1431 100644 --- a/smartsim/_core/mli/infrastructure/control/workermanager.py +++ b/smartsim/_core/mli/infrastructure/control/workermanager.py @@ -86,6 +86,7 @@ def deserialize_message( input_bytes: t.Optional[t.List[bytes]] = ( None # these will really be tensors already ) + output_keys: t.Optional[t.List[str]] = None # # client example # msg = Message() @@ -101,12 +102,16 @@ def deserialize_message( input_bytes = [data.blob for data in request.input.data] input_meta = [data.tensorDescriptor for data in request.input.data] + if request.output: + output_keys = [tensor_key.key for tensor_key in request.output] + inference_request = InferenceRequest( model_key=model_key, callback=comm_channel, raw_inputs=input_bytes, - input_meta=input_meta, input_keys=input_keys, + input_meta=input_meta, + output_keys=output_keys, raw_model=model_bytes, batch_size=0, ) diff --git a/tests/mli/test_error_handling.py b/tests/mli/test_error_handling.py index 2bd71d7b22..74b78bb863 100644 --- a/tests/mli/test_error_handling.py +++ b/tests/mli/test_error_handling.py @@ -48,7 +48,6 @@ from smartsim._core.mli.infrastructure.worker.worker import ( ExecuteResult, FetchInputResult, - FetchModelResult, InferenceReply, LoadModelResult, TransformInputResult, @@ -57,7 +56,6 @@ from smartsim._core.mli.message_handler import MessageHandler from .channel import FileSystemCommChannel -from .featurestore import FileSystemFeatureStore from .worker import IntegratedTorchWorker # The tests in this file belong to the dragon group @@ -108,9 +106,25 @@ def mock_fetch_model(a, b): raise ValueError("Simulated error in fetch_model") monkeypatch.setattr(integrated_worker, "fetch_model", mock_fetch_model) + mock_reply_fn = MagicMock() + monkeypatch.setattr( + "smartsim._core.mli.infrastructure.control.workermanager.build_failure_reply", + mock_reply_fn, + ) + + def mock_exception_handler(exc, reply_channel, func_descriptor, reply): + return exception_handler(exc, None, func_descriptor, reply) + + monkeypatch.setattr( + "smartsim._core.mli.infrastructure.control.workermanager.exception_handler", + mock_exception_handler, + ) worker_manager._on_iteration() + assert mock_reply_fn.called_once() + mock_reply_fn.assert_called_with("fail", "Failed while fetching the model.") + def test_load_model_errors_handled( setup_worker_manager, monkeypatch: pytest.MonkeyPatch @@ -123,8 +137,25 @@ def mock_load_model(a, b): raise ValueError("Simulated error in load_model") monkeypatch.setattr(integrated_worker, "load_model", 
mock_load_model) + mock_reply_fn = MagicMock() + monkeypatch.setattr( + "smartsim._core.mli.infrastructure.control.workermanager.build_failure_reply", + mock_reply_fn, + ) + + def mock_exception_handler(exc, reply_channel, func_descriptor, reply): + return exception_handler(exc, None, func_descriptor, reply) + + monkeypatch.setattr( + "smartsim._core.mli.infrastructure.control.workermanager.exception_handler", + mock_exception_handler, + ) + worker_manager._on_iteration() + assert mock_reply_fn.called_once() + mock_reply_fn.assert_called_with("fail", "Failed while loading the model.") + def test_fetch_inputs_errors_handled( setup_worker_manager, monkeypatch: pytest.MonkeyPatch @@ -137,8 +168,31 @@ def mock_fetch_inputs(a, b): raise ValueError("Simulated error in fetch_inputs") monkeypatch.setattr(integrated_worker, "fetch_inputs", mock_fetch_inputs) + mock_reply_fn = MagicMock() + monkeypatch.setattr( + "smartsim._core.mli.infrastructure.control.workermanager.build_failure_reply", + mock_reply_fn, + ) + + def mock_exception_handler(exc, reply_channel, func_descriptor, reply): + return exception_handler(exc, None, func_descriptor, reply) + + monkeypatch.setattr( + "smartsim._core.mli.infrastructure.control.workermanager.exception_handler", + mock_exception_handler, + ) + + monkeypatch.setattr( + integrated_worker, + "load_model", + MagicMock(return_value=LoadModelResult(b"result_bytes")), + ) + worker_manager._on_iteration() + assert mock_reply_fn.called_once() + mock_reply_fn.assert_called_with("fail", "Failed while fetching the inputs.") + def test_transform_input_errors_handled( setup_worker_manager, monkeypatch: pytest.MonkeyPatch @@ -151,19 +205,60 @@ def mock_transform_input(a, b): raise ValueError("Simulated error in transform_input") monkeypatch.setattr(integrated_worker, "transform_input", mock_transform_input) + mock_reply_fn = MagicMock() + monkeypatch.setattr( + "smartsim._core.mli.infrastructure.control.workermanager.build_failure_reply", + mock_reply_fn, + ) + + def mock_exception_handler(exc, reply_channel, func_descriptor, reply): + return exception_handler(exc, None, func_descriptor, reply) + + monkeypatch.setattr( + "smartsim._core.mli.infrastructure.control.workermanager.exception_handler", + mock_exception_handler, + ) + + monkeypatch.setattr( + integrated_worker, + "load_model", + MagicMock(return_value=LoadModelResult(b"result_bytes")), + ) + monkeypatch.setattr( + integrated_worker, + "fetch_inputs", + MagicMock(return_value=FetchInputResult([b"result_bytes"])), + ) + worker_manager._on_iteration() + assert mock_reply_fn.called_once() + mock_reply_fn.assert_called_with("fail", "Failed while transforming the input.") + def test_execute_errors_handled(setup_worker_manager, monkeypatch: pytest.MonkeyPatch): """Ensures that the worker manager does not crash after a failure in the execute pipeline stage""" worker_manager, integrated_worker = setup_worker_manager + def mock_execute(a, b, c): + raise ValueError("Simulated error in execute") + + monkeypatch.setattr(integrated_worker, "execute", mock_execute) + mock_reply_fn = MagicMock() monkeypatch.setattr( - integrated_worker, - "fetch_model", - MagicMock(return_value=FetchModelResult(b"result_bytes")), + "smartsim._core.mli.infrastructure.control.workermanager.build_failure_reply", + mock_reply_fn, ) + + def mock_exception_handler(exc, reply_channel, func_descriptor, reply): + return exception_handler(exc, None, func_descriptor, reply) + + monkeypatch.setattr( + 
"smartsim._core.mli.infrastructure.control.workermanager.exception_handler", + mock_exception_handler, + ) + monkeypatch.setattr( integrated_worker, "load_model", @@ -180,19 +275,8 @@ def test_execute_errors_handled(setup_worker_manager, monkeypatch: pytest.Monkey MagicMock(return_value=TransformInputResult(b"result_bytes")), ) - def mock_execute(a, b, c): - raise ValueError("Simulated error in execute") - - monkeypatch.setattr(integrated_worker, "execute", mock_execute) - worker_manager._on_iteration() - mock_reply_fn = MagicMock() - monkeypatch.setattr( - "smartsim._core.mli.infrastructure.control.workermanager.build_failure_reply", - mock_reply_fn, - ) - assert mock_reply_fn.called_once() mock_reply_fn.assert_called_with("fail", "Failed while executing.") @@ -208,8 +292,46 @@ def mock_transform_output(a, b): raise ValueError("Simulated error in transform_output") monkeypatch.setattr(integrated_worker, "transform_output", mock_transform_output) + mock_reply_fn = MagicMock() + monkeypatch.setattr( + "smartsim._core.mli.infrastructure.control.workermanager.build_failure_reply", + mock_reply_fn, + ) + + def mock_exception_handler(exc, reply_channel, func_descriptor, reply): + return exception_handler(exc, None, func_descriptor, reply) + + monkeypatch.setattr( + "smartsim._core.mli.infrastructure.control.workermanager.exception_handler", + mock_exception_handler, + ) + + monkeypatch.setattr( + integrated_worker, + "load_model", + MagicMock(return_value=LoadModelResult(b"result_bytes")), + ) + monkeypatch.setattr( + integrated_worker, + "fetch_inputs", + MagicMock(return_value=FetchInputResult([b"result_bytes"])), + ) + monkeypatch.setattr( + integrated_worker, + "transform_input", + MagicMock(return_value=TransformInputResult(b"result_bytes")), + ) + monkeypatch.setattr( + integrated_worker, + "execute", + MagicMock(return_value=ExecuteResult(b"result_bytes")), + ) + worker_manager._on_iteration() + assert mock_reply_fn.called_once() + mock_reply_fn.assert_called_with("fail", "Failed while transforming the output.") + def test_place_output_errors_handled( setup_worker_manager, monkeypatch: pytest.MonkeyPatch @@ -222,8 +344,51 @@ def mock_place_output(a, b, c): raise ValueError("Simulated error in place_output") monkeypatch.setattr(integrated_worker, "place_output", mock_place_output) + mock_reply_fn = MagicMock() + monkeypatch.setattr( + "smartsim._core.mli.infrastructure.control.workermanager.build_failure_reply", + mock_reply_fn, + ) + + def mock_exception_handler(exc, reply_channel, func_descriptor, reply): + return exception_handler(exc, None, func_descriptor, reply) + + monkeypatch.setattr( + "smartsim._core.mli.infrastructure.control.workermanager.exception_handler", + mock_exception_handler, + ) + + monkeypatch.setattr( + integrated_worker, + "load_model", + MagicMock(return_value=LoadModelResult(b"result_bytes")), + ) + monkeypatch.setattr( + integrated_worker, + "fetch_inputs", + MagicMock(return_value=FetchInputResult([b"result_bytes"])), + ) + monkeypatch.setattr( + integrated_worker, + "transform_input", + MagicMock(return_value=TransformInputResult(b"result_bytes")), + ) + monkeypatch.setattr( + integrated_worker, + "execute", + MagicMock(return_value=ExecuteResult(b"result_bytes")), + ) + monkeypatch.setattr( + integrated_worker, + "transform_output", + MagicMock(return_value=TransformOutputResult(b"result", [], "c", "float32")), + ) + worker_manager._on_iteration() + assert mock_reply_fn.called_once() + mock_reply_fn.assert_called_with("fail", "Failed while placing the 
output.") + def test_exception_handling_helper(): """Ensures that the worker manager does not crash after a failure in the From f30bc37ddcfd22359f298389cc67dfafd9f9ee05 Mon Sep 17 00:00:00 2001 From: Alyssa Cote Date: Thu, 11 Jul 2024 13:29:53 -0500 Subject: [PATCH 37/57] moving tests --- tests/dragon/test_error_handling.py | 402 ++++++++++++++++++++++++++++ tests/dragon/test_reply_building.py | 91 +++++++ tests/dragon/utils/channel.py | 59 ++++ tests/dragon/utils/worker.py | 128 +++++++++ 4 files changed, 680 insertions(+) create mode 100644 tests/dragon/test_error_handling.py create mode 100644 tests/dragon/test_reply_building.py create mode 100644 tests/dragon/utils/channel.py create mode 100644 tests/dragon/utils/worker.py diff --git a/tests/dragon/test_error_handling.py b/tests/dragon/test_error_handling.py new file mode 100644 index 0000000000..799e1734ba --- /dev/null +++ b/tests/dragon/test_error_handling.py @@ -0,0 +1,402 @@ +# BSD 2-Clause License +# +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ +import base64 +import pickle +from unittest.mock import MagicMock + +import pytest + +dragon = pytest.importorskip("dragon") + +import dragon.utils as du +from dragon.channels import Channel +from dragon.data.ddict.ddict import DDict +from dragon.fli import FLInterface + +from smartsim._core.mli.infrastructure.control.workermanager import ( + WorkerManager, + exception_handler, +) +from smartsim._core.mli.infrastructure.environmentloader import EnvironmentConfigLoader +from smartsim._core.mli.infrastructure.storage.dragonfeaturestore import ( + DragonFeatureStore, +) +from smartsim._core.mli.infrastructure.worker.worker import ( + ExecuteResult, + FetchInputResult, + InferenceReply, + LoadModelResult, + TransformInputResult, + TransformOutputResult, +) +from smartsim._core.mli.message_handler import MessageHandler + +from .utils.channel import FileSystemCommChannel +from .utils.worker import IntegratedTorchWorker + +# The tests in this file belong to the dragon group +pytestmark = pytest.mark.dragon + + +@pytest.fixture +def setup_worker_manager(test_dir, monkeypatch): + integrated_worker = IntegratedTorchWorker() + + chan = Channel.make_process_local() + queue = FLInterface(main_ch=chan) + monkeypatch.setenv("SSQueue", du.B64.bytes_to_str(queue.serialize())) + storage = DDict() + feature_store = DragonFeatureStore(storage) + monkeypatch.setenv( + "SSFeatureStore", base64.b64encode(pickle.dumps(feature_store)).decode("utf-8") + ) + + worker_manager = WorkerManager( + EnvironmentConfigLoader(), + integrated_worker, + as_service=False, + cooldown=3, + comm_channel_type=FileSystemCommChannel, + ) + + tensor_key = MessageHandler.build_tensor_key("key") + model = MessageHandler.build_model(b"model", "model name", "v 0.0.1") + request = MessageHandler.build_request( + b"channel", model, [tensor_key], [tensor_key], [], None + ) + ser_request = MessageHandler.serialize_request(request) + new_sender = worker_manager._task_queue.sendh(use_main_as_stream_channel=True) + new_sender.send_bytes(ser_request) + + return worker_manager, integrated_worker + + +def test_fetch_model_errors_handled( + setup_worker_manager, monkeypatch: pytest.MonkeyPatch +): + """Ensures that the worker manager does not crash after a failure in the + fetch model pipeline stage""" + worker_manager, integrated_worker = setup_worker_manager + + def mock_fetch_model(a, b): + raise ValueError("Simulated error in fetch_model") + + monkeypatch.setattr(integrated_worker, "fetch_model", mock_fetch_model) + mock_reply_fn = MagicMock() + monkeypatch.setattr( + "smartsim._core.mli.infrastructure.control.workermanager.build_failure_reply", + mock_reply_fn, + ) + + def mock_exception_handler(exc, reply_channel, func_descriptor, reply): + return exception_handler(exc, None, func_descriptor, reply) + + monkeypatch.setattr( + "smartsim._core.mli.infrastructure.control.workermanager.exception_handler", + mock_exception_handler, + ) + + worker_manager._on_iteration() + + assert mock_reply_fn.called_once() + mock_reply_fn.assert_called_with("fail", "Failed while fetching the model.") + + +def test_load_model_errors_handled( + setup_worker_manager, monkeypatch: pytest.MonkeyPatch +): + """Ensures that the worker manager does not crash after a failure in the + load model pipeline stage""" + worker_manager, integrated_worker = setup_worker_manager + + def mock_load_model(a, b): + raise ValueError("Simulated error in load_model") + + monkeypatch.setattr(integrated_worker, "load_model", mock_load_model) + mock_reply_fn = MagicMock() + 
monkeypatch.setattr( + "smartsim._core.mli.infrastructure.control.workermanager.build_failure_reply", + mock_reply_fn, + ) + + def mock_exception_handler(exc, reply_channel, func_descriptor, reply): + return exception_handler(exc, None, func_descriptor, reply) + + monkeypatch.setattr( + "smartsim._core.mli.infrastructure.control.workermanager.exception_handler", + mock_exception_handler, + ) + + worker_manager._on_iteration() + + assert mock_reply_fn.called_once() + mock_reply_fn.assert_called_with("fail", "Failed while loading the model.") + + +def test_fetch_inputs_errors_handled( + setup_worker_manager, monkeypatch: pytest.MonkeyPatch +): + """Ensures that the worker manager does not crash after a failure in the + fetch inputs pipeline stage""" + worker_manager, integrated_worker = setup_worker_manager + + def mock_fetch_inputs(a, b): + raise ValueError("Simulated error in fetch_inputs") + + monkeypatch.setattr(integrated_worker, "fetch_inputs", mock_fetch_inputs) + mock_reply_fn = MagicMock() + monkeypatch.setattr( + "smartsim._core.mli.infrastructure.control.workermanager.build_failure_reply", + mock_reply_fn, + ) + + def mock_exception_handler(exc, reply_channel, func_descriptor, reply): + return exception_handler(exc, None, func_descriptor, reply) + + monkeypatch.setattr( + "smartsim._core.mli.infrastructure.control.workermanager.exception_handler", + mock_exception_handler, + ) + + monkeypatch.setattr( + integrated_worker, + "load_model", + MagicMock(return_value=LoadModelResult(b"result_bytes")), + ) + + worker_manager._on_iteration() + + assert mock_reply_fn.called_once() + mock_reply_fn.assert_called_with("fail", "Failed while fetching the inputs.") + + +def test_transform_input_errors_handled( + setup_worker_manager, monkeypatch: pytest.MonkeyPatch +): + """Ensures that the worker manager does not crash after a failure in the + transform input pipeline stage""" + worker_manager, integrated_worker = setup_worker_manager + + def mock_transform_input(a, b): + raise ValueError("Simulated error in transform_input") + + monkeypatch.setattr(integrated_worker, "transform_input", mock_transform_input) + mock_reply_fn = MagicMock() + monkeypatch.setattr( + "smartsim._core.mli.infrastructure.control.workermanager.build_failure_reply", + mock_reply_fn, + ) + + def mock_exception_handler(exc, reply_channel, func_descriptor, reply): + return exception_handler(exc, None, func_descriptor, reply) + + monkeypatch.setattr( + "smartsim._core.mli.infrastructure.control.workermanager.exception_handler", + mock_exception_handler, + ) + + monkeypatch.setattr( + integrated_worker, + "load_model", + MagicMock(return_value=LoadModelResult(b"result_bytes")), + ) + monkeypatch.setattr( + integrated_worker, + "fetch_inputs", + MagicMock(return_value=FetchInputResult([b"result_bytes"])), + ) + + worker_manager._on_iteration() + + assert mock_reply_fn.called_once() + mock_reply_fn.assert_called_with("fail", "Failed while transforming the input.") + + +def test_execute_errors_handled(setup_worker_manager, monkeypatch: pytest.MonkeyPatch): + """Ensures that the worker manager does not crash after a failure in the + execute pipeline stage""" + worker_manager, integrated_worker = setup_worker_manager + + def mock_execute(a, b, c): + raise ValueError("Simulated error in execute") + + monkeypatch.setattr(integrated_worker, "execute", mock_execute) + mock_reply_fn = MagicMock() + monkeypatch.setattr( + "smartsim._core.mli.infrastructure.control.workermanager.build_failure_reply", + mock_reply_fn, + ) + + def 
mock_exception_handler(exc, reply_channel, func_descriptor, reply): + return exception_handler(exc, None, func_descriptor, reply) + + monkeypatch.setattr( + "smartsim._core.mli.infrastructure.control.workermanager.exception_handler", + mock_exception_handler, + ) + + monkeypatch.setattr( + integrated_worker, + "load_model", + MagicMock(return_value=LoadModelResult(b"result_bytes")), + ) + monkeypatch.setattr( + integrated_worker, + "fetch_inputs", + MagicMock(return_value=FetchInputResult([b"result_bytes"])), + ) + monkeypatch.setattr( + integrated_worker, + "transform_input", + MagicMock(return_value=TransformInputResult(b"result_bytes")), + ) + + worker_manager._on_iteration() + + assert mock_reply_fn.called_once() + mock_reply_fn.assert_called_with("fail", "Failed while executing.") + + +def test_transform_output_errors_handled( + setup_worker_manager, monkeypatch: pytest.MonkeyPatch +): + """Ensures that the worker manager does not crash after a failure in the + transform output pipeline stage""" + worker_manager, integrated_worker = setup_worker_manager + + def mock_transform_output(a, b): + raise ValueError("Simulated error in transform_output") + + monkeypatch.setattr(integrated_worker, "transform_output", mock_transform_output) + mock_reply_fn = MagicMock() + monkeypatch.setattr( + "smartsim._core.mli.infrastructure.control.workermanager.build_failure_reply", + mock_reply_fn, + ) + + def mock_exception_handler(exc, reply_channel, func_descriptor, reply): + return exception_handler(exc, None, func_descriptor, reply) + + monkeypatch.setattr( + "smartsim._core.mli.infrastructure.control.workermanager.exception_handler", + mock_exception_handler, + ) + + monkeypatch.setattr( + integrated_worker, + "load_model", + MagicMock(return_value=LoadModelResult(b"result_bytes")), + ) + monkeypatch.setattr( + integrated_worker, + "fetch_inputs", + MagicMock(return_value=FetchInputResult([b"result_bytes"])), + ) + monkeypatch.setattr( + integrated_worker, + "transform_input", + MagicMock(return_value=TransformInputResult(b"result_bytes")), + ) + monkeypatch.setattr( + integrated_worker, + "execute", + MagicMock(return_value=ExecuteResult(b"result_bytes")), + ) + + worker_manager._on_iteration() + + assert mock_reply_fn.called_once() + mock_reply_fn.assert_called_with("fail", "Failed while transforming the output.") + + +def test_place_output_errors_handled( + setup_worker_manager, monkeypatch: pytest.MonkeyPatch +): + """Ensures that the worker manager does not crash after a failure in the + place output pipeline stage""" + worker_manager, integrated_worker = setup_worker_manager + + def mock_place_output(a, b, c): + raise ValueError("Simulated error in place_output") + + monkeypatch.setattr(integrated_worker, "place_output", mock_place_output) + mock_reply_fn = MagicMock() + monkeypatch.setattr( + "smartsim._core.mli.infrastructure.control.workermanager.build_failure_reply", + mock_reply_fn, + ) + + def mock_exception_handler(exc, reply_channel, func_descriptor, reply): + return exception_handler(exc, None, func_descriptor, reply) + + monkeypatch.setattr( + "smartsim._core.mli.infrastructure.control.workermanager.exception_handler", + mock_exception_handler, + ) + + monkeypatch.setattr( + integrated_worker, + "load_model", + MagicMock(return_value=LoadModelResult(b"result_bytes")), + ) + monkeypatch.setattr( + integrated_worker, + "fetch_inputs", + MagicMock(return_value=FetchInputResult([b"result_bytes"])), + ) + monkeypatch.setattr( + integrated_worker, + "transform_input", + 
MagicMock(return_value=TransformInputResult(b"result_bytes")), + ) + monkeypatch.setattr( + integrated_worker, + "execute", + MagicMock(return_value=ExecuteResult(b"result_bytes")), + ) + monkeypatch.setattr( + integrated_worker, + "transform_output", + MagicMock(return_value=TransformOutputResult(b"result", [], "c", "float32")), + ) + + worker_manager._on_iteration() + + assert mock_reply_fn.called_once() + mock_reply_fn.assert_called_with("fail", "Failed while placing the output.") + + +def test_exception_handling_helper(): + """Ensures that the worker manager does not crash after a failure in the + execute pipeline stage""" + reply = InferenceReply() + + test_exception = ValueError("Test ValueError") + exception_handler(test_exception, None, "fetching the model", reply) + + assert reply.status_enum == "fail" + assert reply.message == "Failed while fetching the model." diff --git a/tests/dragon/test_reply_building.py b/tests/dragon/test_reply_building.py new file mode 100644 index 0000000000..f585096788 --- /dev/null +++ b/tests/dragon/test_reply_building.py @@ -0,0 +1,91 @@ +# BSD 2-Clause License +# +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ +import typing as t + +import pytest + +dragon = pytest.importorskip("dragon") + +from smartsim._core.mli.infrastructure.control.workermanager import ( + build_failure_reply, + build_reply, +) +from smartsim._core.mli.infrastructure.worker.worker import InferenceReply + +if t.TYPE_CHECKING: + from smartsim._core.mli.mli_schemas.response.response_capnp import StatusEnum + +# The tests in this file belong to the dragon group +pytestmark = pytest.mark.dragon + + +@pytest.mark.parametrize( + "status, message", + [ + pytest.param("timeout", "Worker timed out", id="timeout"), + pytest.param("fail", "Failed while executing", id="fail"), + ], +) +def test_build_failure_reply(status: "StatusEnum", message: str): + "Ensures failure replies can be built successfully" + response = build_failure_reply(status, message) + assert response.status == status + assert response.message == message + + +def test_build_failure_reply_fails(): + "Ensures ValueError is raised if a StatusEnum is not used" + with pytest.raises(ValueError) as ex: + response = build_failure_reply("not a status enum", "message") + + assert "Error assigning status to response" in ex.value.args[0] + + +@pytest.mark.parametrize( + "status, message", + [ + pytest.param("complete", "Success", id="complete"), + ], +) +def test_build_reply(status: "StatusEnum", message: str): + "Ensures replies can be built successfully" + reply = InferenceReply() + reply.status_enum = status + reply.message = message + response = build_reply(reply) + assert response.status == status + assert response.message == message + + +def test_build_reply_fails(): + "Ensures ValueError is raised if a StatusEnum is not used" + with pytest.raises(ValueError) as ex: + reply = InferenceReply() + reply.status_enum = "not a status enum" + response = build_reply(reply) + + assert "Error assigning status to response" in ex.value.args[0] diff --git a/tests/dragon/utils/channel.py b/tests/dragon/utils/channel.py new file mode 100644 index 0000000000..79bf495f83 --- /dev/null +++ b/tests/dragon/utils/channel.py @@ -0,0 +1,59 @@ +# BSD 2-Clause License +# +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
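+# FileSystemCommChannel is a file-backed stand-in for a real communication
+# channel: the descriptor encodes a file path, and send() writes the raw bytes
+# to that file so tests can inspect replies without a live Dragon channel.
+# Hypothetical usage sketch (illustrative only, path is made up):
+#
+#     channel = FileSystemCommChannel(pathlib.Path("/tmp/mock_reply_channel"))
+#     channel.send(b"serialized-response-bytes")
+#     assert pathlib.Path("/tmp/mock_reply_channel").read_bytes() != b""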
+
+import pathlib
+import typing as t
+
+from smartsim._core.mli.comm.channel.channel import CommChannelBase
+from smartsim.log import get_logger
+
+logger = get_logger(__name__)
+
+
+class FileSystemCommChannel(CommChannelBase):
+    """Passes messages by writing to a file"""
+
+    def __init__(self, key: t.Union[bytes, pathlib.Path]) -> None:
+        """Initialize the FileSystemCommChannel instance"""
+        if not isinstance(key, bytes):
+            super().__init__(key.as_posix().encode("utf-8"))
+            self._file_path = key
+        else:
+            super().__init__(key)
+            self._file_path = pathlib.Path(key.decode("utf-8"))
+
+        if not self._file_path.parent.exists():
+            self._file_path.parent.mkdir(parents=True)
+
+        self._file_path.touch()
+
+    def send(self, value: bytes) -> None:
+        """Send a message through the underlying communication channel
+        :param value: The value to send"""
+        logger.debug(
+            f"Channel {self.descriptor.decode('utf-8')} sending message to {self._file_path}"
+        )
+        self._file_path.write_bytes(value)
\ No newline at end of file
diff --git a/tests/dragon/utils/worker.py b/tests/dragon/utils/worker.py
new file mode 100644
index 0000000000..b1de280185
--- /dev/null
+++ b/tests/dragon/utils/worker.py
@@ -0,0 +1,128 @@
+# BSD 2-Clause License
+#
+# Copyright (c) 2021-2024, Hewlett Packard Enterprise
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# 1. Redistributions of source code must retain the above copyright notice, this
+#    list of conditions and the following disclaimer.
+#
+# 2. Redistributions in binary form must reproduce the above copyright notice,
+#    this list of conditions and the following disclaimer in the documentation
+#    and/or other materials provided with the distribution.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
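+# IntegratedTorchWorker provides minimal implementations of the pipeline
+# stages defined by MachineLearningWorkerBase: load_model deserializes a
+# torch.nn.Module from bytes, transform_input rebuilds tensors from raw
+# buffers, execute runs the model over the transformed tensors, and
+# transform_output detaches the predictions for the reply. Rough sketch of how
+# a manager might drive the stages (variable names are illustrative only):
+#
+#     load_result = IntegratedTorchWorker.load_model(request, fetch_model_result)
+#     inputs = IntegratedTorchWorker.transform_input(request, fetch_input_result)
+#     executed = IntegratedTorchWorker.execute(request, load_result, inputs)
+#     output = IntegratedTorchWorker.transform_output(request, executed)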
+ +import io +import typing as t + +import torch + +import smartsim._core.mli.infrastructure.worker.worker as mliw +import smartsim.error as sse +from smartsim.log import get_logger + +logger = get_logger(__name__) + + +class IntegratedTorchWorker(mliw.MachineLearningWorkerBase): + """A minimum implementation of a worker that executes a PyTorch model""" + + # @staticmethod + # def deserialize(request: InferenceRequest) -> t.List[t.Any]: + # # request.input_meta + # # request.raw_inputs + # return request + + @staticmethod + def load_model( + request: mliw.InferenceRequest, fetch_result: mliw.FetchModelResult + ) -> mliw.LoadModelResult: + model_bytes = fetch_result.model_bytes or request.raw_model + if not model_bytes: + raise ValueError("Unable to load model without reference object") + + model: torch.nn.Module = torch.load(io.BytesIO(model_bytes)) + result = mliw.LoadModelResult(model) + return result + + @staticmethod + def transform_input( + request: mliw.InferenceRequest, + fetch_result: mliw.FetchInputResult, + ) -> mliw.TransformInputResult: + # extra metadata for assembly can be found in request.input_meta + raw_inputs = request.raw_inputs or fetch_result.inputs + + result: t.List[torch.Tensor] = [] + # should this happen here? + # consider - fortran to c data layout + # is there an intermediate representation before really doing torch.load? + if raw_inputs: + result = [torch.load(io.BytesIO(item)) for item in raw_inputs] + + return mliw.TransformInputResult(result) + + @staticmethod + def execute( + request: mliw.InferenceRequest, + load_result: mliw.LoadModelResult, + transform_result: mliw.TransformInputResult, + ) -> mliw.ExecuteResult: + if not load_result.model: + raise sse.SmartSimError("Model must be loaded to execute") + + model = load_result.model + results = [model(tensor) for tensor in transform_result.transformed] + + execute_result = mliw.ExecuteResult(results) + return execute_result + + @staticmethod + def transform_output( + request: mliw.InferenceRequest, + execute_result: mliw.ExecuteResult, + ) -> mliw.TransformOutputResult: + # transformed = [item.clone() for item in execute_result.predictions] + # return OutputTransformResult(transformed) + + # transformed = [item.bytes() for item in execute_result.predictions] + + # OutputTransformResult.transformed SHOULD be a list of + # capnproto Tensors Or tensor descriptors accompanying bytes + + # send the original tensors... + execute_result.predictions = [t.detach() for t in execute_result.predictions] + # todo: solve sending all tensor metadata that coincisdes with each prediction + return mliw.TransformOutputResult( + execute_result.predictions, [1], "c", "float32" + ) + # return OutputTransformResult(transformed) + + # @staticmethod + # def serialize_reply( + # request: InferenceRequest, results: OutputTransformResult + # ) -> t.Any: + # # results = IntegratedTorchWorker._prepare_outputs(results.outputs) + # # return results + # return None + # # response = MessageHandler.build_response( + # # status=200, # todo: are we satisfied with 0/1 (success, fail) + # # # todo: if not detailed messages, this shouldn't be returned. 
+ # # message="success", + # # result=results, + # # custom_attributes=None, + # # ) + # # serialized_resp = MessageHandler.serialize_response(response) + # # return serialized_resp From 437837d4ecd6c7a004cd2aca9b1de715dfabcc9a Mon Sep 17 00:00:00 2001 From: Alyssa Cote Date: Thu, 11 Jul 2024 13:38:40 -0500 Subject: [PATCH 38/57] remove mli tests --- tests/mli/test_error_handling.py | 402 ------------------------------- tests/mli/test_reply_building.py | 91 ------- 2 files changed, 493 deletions(-) delete mode 100644 tests/mli/test_error_handling.py delete mode 100644 tests/mli/test_reply_building.py diff --git a/tests/mli/test_error_handling.py b/tests/mli/test_error_handling.py deleted file mode 100644 index 74b78bb863..0000000000 --- a/tests/mli/test_error_handling.py +++ /dev/null @@ -1,402 +0,0 @@ -# BSD 2-Clause License -# -# Copyright (c) 2021-2024, Hewlett Packard Enterprise -# All rights reserved. -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are met: -# -# 1. Redistributions of source code must retain the above copyright notice, this -# list of conditions and the following disclaimer. -# -# 2. Redistributions in binary form must reproduce the above copyright notice, -# this list of conditions and the following disclaimer in the documentation -# and/or other materials provided with the distribution. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
- -import base64 -import pickle -from unittest.mock import MagicMock - -import pytest - -dragon = pytest.importorskip("dragon") - -import dragon.utils as du -from dragon.channels import Channel -from dragon.data.ddict.ddict import DDict -from dragon.fli import FLInterface - -from smartsim._core.mli.infrastructure.control.workermanager import ( - WorkerManager, - exception_handler, -) -from smartsim._core.mli.infrastructure.environmentloader import EnvironmentConfigLoader -from smartsim._core.mli.infrastructure.storage.dragonfeaturestore import ( - DragonFeatureStore, -) -from smartsim._core.mli.infrastructure.worker.worker import ( - ExecuteResult, - FetchInputResult, - InferenceReply, - LoadModelResult, - TransformInputResult, - TransformOutputResult, -) -from smartsim._core.mli.message_handler import MessageHandler - -from .channel import FileSystemCommChannel -from .worker import IntegratedTorchWorker - -# The tests in this file belong to the dragon group -pytestmark = pytest.mark.dragon - - -@pytest.fixture -def setup_worker_manager(test_dir, monkeypatch): - integrated_worker = IntegratedTorchWorker() - - chan = Channel.make_process_local() - queue = FLInterface(main_ch=chan) - monkeypatch.setenv("SSQueue", du.B64.bytes_to_str(queue.serialize())) - storage = DDict() - feature_store = DragonFeatureStore(storage) - monkeypatch.setenv( - "SSFeatureStore", base64.b64encode(pickle.dumps(feature_store)).decode("utf-8") - ) - - worker_manager = WorkerManager( - EnvironmentConfigLoader(), - integrated_worker, - as_service=False, - cooldown=3, - comm_channel_type=FileSystemCommChannel, - ) - - tensor_key = MessageHandler.build_tensor_key("key") - model = MessageHandler.build_model(b"model", "model name", "v 0.0.1") - request = MessageHandler.build_request( - b"channel", model, [tensor_key], [tensor_key], [], None - ) - ser_request = MessageHandler.serialize_request(request) - new_sender = worker_manager._task_queue.sendh(use_main_as_stream_channel=True) - new_sender.send_bytes(ser_request) - - return worker_manager, integrated_worker - - -def test_fetch_model_errors_handled( - setup_worker_manager, monkeypatch: pytest.MonkeyPatch -): - """Ensures that the worker manager does not crash after a failure in the - fetch model pipeline stage""" - worker_manager, integrated_worker = setup_worker_manager - - def mock_fetch_model(a, b): - raise ValueError("Simulated error in fetch_model") - - monkeypatch.setattr(integrated_worker, "fetch_model", mock_fetch_model) - mock_reply_fn = MagicMock() - monkeypatch.setattr( - "smartsim._core.mli.infrastructure.control.workermanager.build_failure_reply", - mock_reply_fn, - ) - - def mock_exception_handler(exc, reply_channel, func_descriptor, reply): - return exception_handler(exc, None, func_descriptor, reply) - - monkeypatch.setattr( - "smartsim._core.mli.infrastructure.control.workermanager.exception_handler", - mock_exception_handler, - ) - - worker_manager._on_iteration() - - assert mock_reply_fn.called_once() - mock_reply_fn.assert_called_with("fail", "Failed while fetching the model.") - - -def test_load_model_errors_handled( - setup_worker_manager, monkeypatch: pytest.MonkeyPatch -): - """Ensures that the worker manager does not crash after a failure in the - load model pipeline stage""" - worker_manager, integrated_worker = setup_worker_manager - - def mock_load_model(a, b): - raise ValueError("Simulated error in load_model") - - monkeypatch.setattr(integrated_worker, "load_model", mock_load_model) - mock_reply_fn = MagicMock() - monkeypatch.setattr( - 
"smartsim._core.mli.infrastructure.control.workermanager.build_failure_reply", - mock_reply_fn, - ) - - def mock_exception_handler(exc, reply_channel, func_descriptor, reply): - return exception_handler(exc, None, func_descriptor, reply) - - monkeypatch.setattr( - "smartsim._core.mli.infrastructure.control.workermanager.exception_handler", - mock_exception_handler, - ) - - worker_manager._on_iteration() - - assert mock_reply_fn.called_once() - mock_reply_fn.assert_called_with("fail", "Failed while loading the model.") - - -def test_fetch_inputs_errors_handled( - setup_worker_manager, monkeypatch: pytest.MonkeyPatch -): - """Ensures that the worker manager does not crash after a failure in the - fetch inputs pipeline stage""" - worker_manager, integrated_worker = setup_worker_manager - - def mock_fetch_inputs(a, b): - raise ValueError("Simulated error in fetch_inputs") - - monkeypatch.setattr(integrated_worker, "fetch_inputs", mock_fetch_inputs) - mock_reply_fn = MagicMock() - monkeypatch.setattr( - "smartsim._core.mli.infrastructure.control.workermanager.build_failure_reply", - mock_reply_fn, - ) - - def mock_exception_handler(exc, reply_channel, func_descriptor, reply): - return exception_handler(exc, None, func_descriptor, reply) - - monkeypatch.setattr( - "smartsim._core.mli.infrastructure.control.workermanager.exception_handler", - mock_exception_handler, - ) - - monkeypatch.setattr( - integrated_worker, - "load_model", - MagicMock(return_value=LoadModelResult(b"result_bytes")), - ) - - worker_manager._on_iteration() - - assert mock_reply_fn.called_once() - mock_reply_fn.assert_called_with("fail", "Failed while fetching the inputs.") - - -def test_transform_input_errors_handled( - setup_worker_manager, monkeypatch: pytest.MonkeyPatch -): - """Ensures that the worker manager does not crash after a failure in the - transform input pipeline stage""" - worker_manager, integrated_worker = setup_worker_manager - - def mock_transform_input(a, b): - raise ValueError("Simulated error in transform_input") - - monkeypatch.setattr(integrated_worker, "transform_input", mock_transform_input) - mock_reply_fn = MagicMock() - monkeypatch.setattr( - "smartsim._core.mli.infrastructure.control.workermanager.build_failure_reply", - mock_reply_fn, - ) - - def mock_exception_handler(exc, reply_channel, func_descriptor, reply): - return exception_handler(exc, None, func_descriptor, reply) - - monkeypatch.setattr( - "smartsim._core.mli.infrastructure.control.workermanager.exception_handler", - mock_exception_handler, - ) - - monkeypatch.setattr( - integrated_worker, - "load_model", - MagicMock(return_value=LoadModelResult(b"result_bytes")), - ) - monkeypatch.setattr( - integrated_worker, - "fetch_inputs", - MagicMock(return_value=FetchInputResult([b"result_bytes"])), - ) - - worker_manager._on_iteration() - - assert mock_reply_fn.called_once() - mock_reply_fn.assert_called_with("fail", "Failed while transforming the input.") - - -def test_execute_errors_handled(setup_worker_manager, monkeypatch: pytest.MonkeyPatch): - """Ensures that the worker manager does not crash after a failure in the - execute pipeline stage""" - worker_manager, integrated_worker = setup_worker_manager - - def mock_execute(a, b, c): - raise ValueError("Simulated error in execute") - - monkeypatch.setattr(integrated_worker, "execute", mock_execute) - mock_reply_fn = MagicMock() - monkeypatch.setattr( - "smartsim._core.mli.infrastructure.control.workermanager.build_failure_reply", - mock_reply_fn, - ) - - def mock_exception_handler(exc, 
reply_channel, func_descriptor, reply): - return exception_handler(exc, None, func_descriptor, reply) - - monkeypatch.setattr( - "smartsim._core.mli.infrastructure.control.workermanager.exception_handler", - mock_exception_handler, - ) - - monkeypatch.setattr( - integrated_worker, - "load_model", - MagicMock(return_value=LoadModelResult(b"result_bytes")), - ) - monkeypatch.setattr( - integrated_worker, - "fetch_inputs", - MagicMock(return_value=FetchInputResult([b"result_bytes"])), - ) - monkeypatch.setattr( - integrated_worker, - "transform_input", - MagicMock(return_value=TransformInputResult(b"result_bytes")), - ) - - worker_manager._on_iteration() - - assert mock_reply_fn.called_once() - mock_reply_fn.assert_called_with("fail", "Failed while executing.") - - -def test_transform_output_errors_handled( - setup_worker_manager, monkeypatch: pytest.MonkeyPatch -): - """Ensures that the worker manager does not crash after a failure in the - transform output pipeline stage""" - worker_manager, integrated_worker = setup_worker_manager - - def mock_transform_output(a, b): - raise ValueError("Simulated error in transform_output") - - monkeypatch.setattr(integrated_worker, "transform_output", mock_transform_output) - mock_reply_fn = MagicMock() - monkeypatch.setattr( - "smartsim._core.mli.infrastructure.control.workermanager.build_failure_reply", - mock_reply_fn, - ) - - def mock_exception_handler(exc, reply_channel, func_descriptor, reply): - return exception_handler(exc, None, func_descriptor, reply) - - monkeypatch.setattr( - "smartsim._core.mli.infrastructure.control.workermanager.exception_handler", - mock_exception_handler, - ) - - monkeypatch.setattr( - integrated_worker, - "load_model", - MagicMock(return_value=LoadModelResult(b"result_bytes")), - ) - monkeypatch.setattr( - integrated_worker, - "fetch_inputs", - MagicMock(return_value=FetchInputResult([b"result_bytes"])), - ) - monkeypatch.setattr( - integrated_worker, - "transform_input", - MagicMock(return_value=TransformInputResult(b"result_bytes")), - ) - monkeypatch.setattr( - integrated_worker, - "execute", - MagicMock(return_value=ExecuteResult(b"result_bytes")), - ) - - worker_manager._on_iteration() - - assert mock_reply_fn.called_once() - mock_reply_fn.assert_called_with("fail", "Failed while transforming the output.") - - -def test_place_output_errors_handled( - setup_worker_manager, monkeypatch: pytest.MonkeyPatch -): - """Ensures that the worker manager does not crash after a failure in the - place output pipeline stage""" - worker_manager, integrated_worker = setup_worker_manager - - def mock_place_output(a, b, c): - raise ValueError("Simulated error in place_output") - - monkeypatch.setattr(integrated_worker, "place_output", mock_place_output) - mock_reply_fn = MagicMock() - monkeypatch.setattr( - "smartsim._core.mli.infrastructure.control.workermanager.build_failure_reply", - mock_reply_fn, - ) - - def mock_exception_handler(exc, reply_channel, func_descriptor, reply): - return exception_handler(exc, None, func_descriptor, reply) - - monkeypatch.setattr( - "smartsim._core.mli.infrastructure.control.workermanager.exception_handler", - mock_exception_handler, - ) - - monkeypatch.setattr( - integrated_worker, - "load_model", - MagicMock(return_value=LoadModelResult(b"result_bytes")), - ) - monkeypatch.setattr( - integrated_worker, - "fetch_inputs", - MagicMock(return_value=FetchInputResult([b"result_bytes"])), - ) - monkeypatch.setattr( - integrated_worker, - "transform_input", - 
MagicMock(return_value=TransformInputResult(b"result_bytes")), - ) - monkeypatch.setattr( - integrated_worker, - "execute", - MagicMock(return_value=ExecuteResult(b"result_bytes")), - ) - monkeypatch.setattr( - integrated_worker, - "transform_output", - MagicMock(return_value=TransformOutputResult(b"result", [], "c", "float32")), - ) - - worker_manager._on_iteration() - - assert mock_reply_fn.called_once() - mock_reply_fn.assert_called_with("fail", "Failed while placing the output.") - - -def test_exception_handling_helper(): - """Ensures that the worker manager does not crash after a failure in the - execute pipeline stage""" - reply = InferenceReply() - - test_exception = ValueError("Test ValueError") - exception_handler(test_exception, None, "fetching the model", reply) - - assert reply.status_enum == "fail" - assert reply.message == "Failed while fetching the model." diff --git a/tests/mli/test_reply_building.py b/tests/mli/test_reply_building.py deleted file mode 100644 index f585096788..0000000000 --- a/tests/mli/test_reply_building.py +++ /dev/null @@ -1,91 +0,0 @@ -# BSD 2-Clause License -# -# Copyright (c) 2021-2024, Hewlett Packard Enterprise -# All rights reserved. -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are met: -# -# 1. Redistributions of source code must retain the above copyright notice, this -# list of conditions and the following disclaimer. -# -# 2. Redistributions in binary form must reproduce the above copyright notice, -# this list of conditions and the following disclaimer in the documentation -# and/or other materials provided with the distribution. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
- -import typing as t - -import pytest - -dragon = pytest.importorskip("dragon") - -from smartsim._core.mli.infrastructure.control.workermanager import ( - build_failure_reply, - build_reply, -) -from smartsim._core.mli.infrastructure.worker.worker import InferenceReply - -if t.TYPE_CHECKING: - from smartsim._core.mli.mli_schemas.response.response_capnp import StatusEnum - -# The tests in this file belong to the dragon group -pytestmark = pytest.mark.dragon - - -@pytest.mark.parametrize( - "status, message", - [ - pytest.param("timeout", "Worker timed out", id="timeout"), - pytest.param("fail", "Failed while executing", id="fail"), - ], -) -def test_build_failure_reply(status: "StatusEnum", message: str): - "Ensures failure replies can be built successfully" - response = build_failure_reply(status, message) - assert response.status == status - assert response.message == message - - -def test_build_failure_reply_fails(): - "Ensures ValueError is raised if a StatusEnum is not used" - with pytest.raises(ValueError) as ex: - response = build_failure_reply("not a status enum", "message") - - assert "Error assigning status to response" in ex.value.args[0] - - -@pytest.mark.parametrize( - "status, message", - [ - pytest.param("complete", "Success", id="complete"), - ], -) -def test_build_reply(status: "StatusEnum", message: str): - "Ensures replies can be built successfully" - reply = InferenceReply() - reply.status_enum = status - reply.message = message - response = build_reply(reply) - assert response.status == status - assert response.message == message - - -def test_build_reply_fails(): - "Ensures ValueError is raised if a StatusEnum is not used" - with pytest.raises(ValueError) as ex: - reply = InferenceReply() - reply.status_enum = "not a status enum" - response = build_reply(reply) - - assert "Error assigning status to response" in ex.value.args[0] From 901fcefb5d89f288bc239aa56a0dbe4f841237d8 Mon Sep 17 00:00:00 2001 From: Alyssa Cote Date: Thu, 11 Jul 2024 13:49:22 -0500 Subject: [PATCH 39/57] style --- tests/dragon/utils/channel.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/dragon/utils/channel.py b/tests/dragon/utils/channel.py index 79bf495f83..4bc2014ea3 100644 --- a/tests/dragon/utils/channel.py +++ b/tests/dragon/utils/channel.py @@ -56,4 +56,4 @@ def send(self, value: bytes) -> None: logger.debug( f"Channel {self.descriptor.decode('utf-8')} sending message to {self._file_path}" ) - self._file_path.write_bytes(value) \ No newline at end of file + self._file_path.write_bytes(value) From 4bdcc59b6be98a3af300b74b05c487a446abffba Mon Sep 17 00:00:00 2001 From: Alyssa Cote Date: Thu, 11 Jul 2024 16:04:22 -0500 Subject: [PATCH 40/57] parametrize mock tests --- tests/dragon/test_error_handling.py | 345 +++++++--------------------- 1 file changed, 80 insertions(+), 265 deletions(-) diff --git a/tests/dragon/test_error_handling.py b/tests/dragon/test_error_handling.py index 799e1734ba..4ced3e0ff2 100644 --- a/tests/dragon/test_error_handling.py +++ b/tests/dragon/test_error_handling.py @@ -63,7 +63,7 @@ @pytest.fixture -def setup_worker_manager(test_dir, monkeypatch): +def setup_worker_manager(test_dir, monkeypatch: pytest.MonkeyPatch): integrated_worker = IntegratedTorchWorker() chan = Channel.make_process_local() @@ -86,7 +86,7 @@ def setup_worker_manager(test_dir, monkeypatch): tensor_key = MessageHandler.build_tensor_key("key") model = MessageHandler.build_model(b"model", "model name", "v 0.0.1") request = MessageHandler.build_request( - b"channel", model, 
[tensor_key], [tensor_key], [], None + test_dir, model, [tensor_key], [tensor_key], [], None ) ser_request = MessageHandler.serialize_request(request) new_sender = worker_manager._task_queue.sendh(use_main_as_stream_channel=True) @@ -95,156 +95,11 @@ def setup_worker_manager(test_dir, monkeypatch): return worker_manager, integrated_worker -def test_fetch_model_errors_handled( - setup_worker_manager, monkeypatch: pytest.MonkeyPatch -): - """Ensures that the worker manager does not crash after a failure in the - fetch model pipeline stage""" - worker_manager, integrated_worker = setup_worker_manager - - def mock_fetch_model(a, b): - raise ValueError("Simulated error in fetch_model") - - monkeypatch.setattr(integrated_worker, "fetch_model", mock_fetch_model) - mock_reply_fn = MagicMock() - monkeypatch.setattr( - "smartsim._core.mli.infrastructure.control.workermanager.build_failure_reply", - mock_reply_fn, - ) - - def mock_exception_handler(exc, reply_channel, func_descriptor, reply): - return exception_handler(exc, None, func_descriptor, reply) - - monkeypatch.setattr( - "smartsim._core.mli.infrastructure.control.workermanager.exception_handler", - mock_exception_handler, - ) - - worker_manager._on_iteration() - - assert mock_reply_fn.called_once() - mock_reply_fn.assert_called_with("fail", "Failed while fetching the model.") - - -def test_load_model_errors_handled( - setup_worker_manager, monkeypatch: pytest.MonkeyPatch -): - """Ensures that the worker manager does not crash after a failure in the - load model pipeline stage""" - worker_manager, integrated_worker = setup_worker_manager - - def mock_load_model(a, b): - raise ValueError("Simulated error in load_model") - - monkeypatch.setattr(integrated_worker, "load_model", mock_load_model) - mock_reply_fn = MagicMock() - monkeypatch.setattr( - "smartsim._core.mli.infrastructure.control.workermanager.build_failure_reply", - mock_reply_fn, - ) - - def mock_exception_handler(exc, reply_channel, func_descriptor, reply): - return exception_handler(exc, None, func_descriptor, reply) - - monkeypatch.setattr( - "smartsim._core.mli.infrastructure.control.workermanager.exception_handler", - mock_exception_handler, - ) - - worker_manager._on_iteration() - - assert mock_reply_fn.called_once() - mock_reply_fn.assert_called_with("fail", "Failed while loading the model.") - - -def test_fetch_inputs_errors_handled( - setup_worker_manager, monkeypatch: pytest.MonkeyPatch -): - """Ensures that the worker manager does not crash after a failure in the - fetch inputs pipeline stage""" - worker_manager, integrated_worker = setup_worker_manager - - def mock_fetch_inputs(a, b): - raise ValueError("Simulated error in fetch_inputs") - - monkeypatch.setattr(integrated_worker, "fetch_inputs", mock_fetch_inputs) - mock_reply_fn = MagicMock() - monkeypatch.setattr( - "smartsim._core.mli.infrastructure.control.workermanager.build_failure_reply", - mock_reply_fn, - ) - - def mock_exception_handler(exc, reply_channel, func_descriptor, reply): - return exception_handler(exc, None, func_descriptor, reply) - - monkeypatch.setattr( - "smartsim._core.mli.infrastructure.control.workermanager.exception_handler", - mock_exception_handler, - ) - - monkeypatch.setattr( - integrated_worker, - "load_model", - MagicMock(return_value=LoadModelResult(b"result_bytes")), - ) - - worker_manager._on_iteration() - - assert mock_reply_fn.called_once() - mock_reply_fn.assert_called_with("fail", "Failed while fetching the inputs.") - - -def test_transform_input_errors_handled( - 
setup_worker_manager, monkeypatch: pytest.MonkeyPatch -): - """Ensures that the worker manager does not crash after a failure in the - transform input pipeline stage""" - worker_manager, integrated_worker = setup_worker_manager - - def mock_transform_input(a, b): - raise ValueError("Simulated error in transform_input") - - monkeypatch.setattr(integrated_worker, "transform_input", mock_transform_input) - mock_reply_fn = MagicMock() - monkeypatch.setattr( - "smartsim._core.mli.infrastructure.control.workermanager.build_failure_reply", - mock_reply_fn, - ) - - def mock_exception_handler(exc, reply_channel, func_descriptor, reply): - return exception_handler(exc, None, func_descriptor, reply) - - monkeypatch.setattr( - "smartsim._core.mli.infrastructure.control.workermanager.exception_handler", - mock_exception_handler, - ) - - monkeypatch.setattr( - integrated_worker, - "load_model", - MagicMock(return_value=LoadModelResult(b"result_bytes")), - ) - monkeypatch.setattr( - integrated_worker, - "fetch_inputs", - MagicMock(return_value=FetchInputResult([b"result_bytes"])), - ) - - worker_manager._on_iteration() - - assert mock_reply_fn.called_once() - mock_reply_fn.assert_called_with("fail", "Failed while transforming the input.") - - -def test_execute_errors_handled(setup_worker_manager, monkeypatch: pytest.MonkeyPatch): - """Ensures that the worker manager does not crash after a failure in the - execute pipeline stage""" - worker_manager, integrated_worker = setup_worker_manager - - def mock_execute(a, b, c): - raise ValueError("Simulated error in execute") +def mock_pipeline_stage(monkeypatch: pytest.MonkeyPatch, integrated_worker, stage): + def mock_stage(*args, **kwargs): + raise ValueError(f"Simulated error in {stage}") - monkeypatch.setattr(integrated_worker, "execute", mock_execute) + monkeypatch.setattr(integrated_worker, stage, mock_stage) mock_reply_fn = MagicMock() monkeypatch.setattr( "smartsim._core.mli.infrastructure.control.workermanager.build_failure_reply", @@ -258,136 +113,96 @@ def mock_exception_handler(exc, reply_channel, func_descriptor, reply): "smartsim._core.mli.infrastructure.control.workermanager.exception_handler", mock_exception_handler, ) - - monkeypatch.setattr( - integrated_worker, - "load_model", - MagicMock(return_value=LoadModelResult(b"result_bytes")), - ) - monkeypatch.setattr( - integrated_worker, - "fetch_inputs", - MagicMock(return_value=FetchInputResult([b"result_bytes"])), - ) - monkeypatch.setattr( - integrated_worker, - "transform_input", - MagicMock(return_value=TransformInputResult(b"result_bytes")), - ) - - worker_manager._on_iteration() - - assert mock_reply_fn.called_once() - mock_reply_fn.assert_called_with("fail", "Failed while executing.") - - -def test_transform_output_errors_handled( - setup_worker_manager, monkeypatch: pytest.MonkeyPatch + return mock_reply_fn + + +@pytest.mark.parametrize( + "stage, error_message", + [ + pytest.param( + "fetch_model", "Failed while fetching the model.", id="fetch model" + ), + pytest.param("load_model", "Failed while loading the model.", id="load model"), + pytest.param( + "fetch_inputs", "Failed while fetching the inputs.", id="fetch inputs" + ), + pytest.param( + "transform_input", + "Failed while transforming the input.", + id="transform inputs", + ), + pytest.param("execute", "Failed while executing.", id="execute"), + pytest.param( + "transform_output", + "Failed while transforming the output.", + id="transform output", + ), + pytest.param( + "place_output", "Failed while placing the output.", 
id="place output" + ), + ], +) +def test_pipeline_stage_errors_handled( + setup_worker_manager, + monkeypatch: pytest.MonkeyPatch, + stage: str, + error_message: str, ): - """Ensures that the worker manager does not crash after a failure in the - transform output pipeline stage""" + """Ensures that the worker manager does not crash after a failure in various pipeline stages""" worker_manager, integrated_worker = setup_worker_manager - def mock_transform_output(a, b): - raise ValueError("Simulated error in transform_output") - - monkeypatch.setattr(integrated_worker, "transform_output", mock_transform_output) - mock_reply_fn = MagicMock() - monkeypatch.setattr( - "smartsim._core.mli.infrastructure.control.workermanager.build_failure_reply", - mock_reply_fn, - ) - - def mock_exception_handler(exc, reply_channel, func_descriptor, reply): - return exception_handler(exc, None, func_descriptor, reply) - - monkeypatch.setattr( - "smartsim._core.mli.infrastructure.control.workermanager.exception_handler", - mock_exception_handler, - ) - - monkeypatch.setattr( - integrated_worker, + mock_reply_fn = mock_pipeline_stage(monkeypatch, integrated_worker, stage) + + if stage not in ["fetch_model", "load_model"]: + monkeypatch.setattr( + integrated_worker, + "load_model", + MagicMock(return_value=LoadModelResult(b"result_bytes")), + ) + if stage not in ["fetch_model", "load_model", "fetch_inputs"]: + monkeypatch.setattr( + integrated_worker, + "fetch_inputs", + MagicMock(return_value=FetchInputResult([b"result_bytes"])), + ) + if stage not in ["fetch_model", "load_model", "fetch_inputs", "transform_input"]: + monkeypatch.setattr( + integrated_worker, + "transform_input", + MagicMock(return_value=TransformInputResult(b"result_bytes")), + ) + if stage not in [ + "fetch_model", "load_model", - MagicMock(return_value=LoadModelResult(b"result_bytes")), - ) - monkeypatch.setattr( - integrated_worker, "fetch_inputs", - MagicMock(return_value=FetchInputResult([b"result_bytes"])), - ) - monkeypatch.setattr( - integrated_worker, "transform_input", - MagicMock(return_value=TransformInputResult(b"result_bytes")), - ) - monkeypatch.setattr( - integrated_worker, "execute", - MagicMock(return_value=ExecuteResult(b"result_bytes")), - ) - - worker_manager._on_iteration() - - assert mock_reply_fn.called_once() - mock_reply_fn.assert_called_with("fail", "Failed while transforming the output.") - - -def test_place_output_errors_handled( - setup_worker_manager, monkeypatch: pytest.MonkeyPatch -): - """Ensures that the worker manager does not crash after a failure in the - place output pipeline stage""" - worker_manager, integrated_worker = setup_worker_manager - - def mock_place_output(a, b, c): - raise ValueError("Simulated error in place_output") - - monkeypatch.setattr(integrated_worker, "place_output", mock_place_output) - mock_reply_fn = MagicMock() - monkeypatch.setattr( - "smartsim._core.mli.infrastructure.control.workermanager.build_failure_reply", - mock_reply_fn, - ) - - def mock_exception_handler(exc, reply_channel, func_descriptor, reply): - return exception_handler(exc, None, func_descriptor, reply) - - monkeypatch.setattr( - "smartsim._core.mli.infrastructure.control.workermanager.exception_handler", - mock_exception_handler, - ) - - monkeypatch.setattr( - integrated_worker, + ]: + monkeypatch.setattr( + integrated_worker, + "execute", + MagicMock(return_value=ExecuteResult(b"result_bytes")), + ) + if stage not in [ + "fetch_model", "load_model", - MagicMock(return_value=LoadModelResult(b"result_bytes")), - ) - 
monkeypatch.setattr( - integrated_worker, "fetch_inputs", - MagicMock(return_value=FetchInputResult([b"result_bytes"])), - ) - monkeypatch.setattr( - integrated_worker, "transform_input", - MagicMock(return_value=TransformInputResult(b"result_bytes")), - ) - monkeypatch.setattr( - integrated_worker, "execute", - MagicMock(return_value=ExecuteResult(b"result_bytes")), - ) - monkeypatch.setattr( - integrated_worker, "transform_output", - MagicMock(return_value=TransformOutputResult(b"result", [], "c", "float32")), - ) + ]: + monkeypatch.setattr( + integrated_worker, + "transform_output", + MagicMock( + return_value=TransformOutputResult(b"result", [], "c", "float32") + ), + ) worker_manager._on_iteration() assert mock_reply_fn.called_once() - mock_reply_fn.assert_called_with("fail", "Failed while placing the output.") + mock_reply_fn.assert_called_with("fail", error_message) def test_exception_handling_helper(): From 13ecd13765391734406f7ff45eb0123b7ee8c376 Mon Sep 17 00:00:00 2001 From: Alyssa Cote Date: Thu, 11 Jul 2024 19:33:44 -0500 Subject: [PATCH 41/57] ignore tests --- Makefile | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/Makefile b/Makefile index aaf1736258..3ab83da892 100644 --- a/Makefile +++ b/Makefile @@ -169,17 +169,17 @@ test: # help: test-verbose - Run all tests verbosely .PHONY: test-verbose test-verbose: - @python -m pytest -vv --ignore=tests/full_wlm/ + @python -m pytest -vv --ignore=tests/full_wlm/ --ignore=tests/dragon # help: test-debug - Run all tests with debug output .PHONY: test-debug test-debug: - @SMARTSIM_LOG_LEVEL=developer python -m pytest -s -o log_cli=true -vv --ignore=tests/full_wlm/ + @SMARTSIM_LOG_LEVEL=developer python -m pytest -s -o log_cli=true -vv --ignore=tests/full_wlm/ --ignore=tests/dragon # help: test-cov - Run all tests with coverage .PHONY: test-cov test-cov: - @python -m pytest -vv --cov=./smartsim --cov-config=${COV_FILE} --ignore=tests/full_wlm/ + @python -m pytest -vv --cov=./smartsim --cov-config=${COV_FILE} --ignore=tests/full_wlm/ --ignore=tests/dragon # help: test-full - Run all WLM tests with Python coverage (full test suite) From 0af26a1c83069734a3211a5f05a5bfaf2e14d7d6 Mon Sep 17 00:00:00 2001 From: Alyssa Cote Date: Tue, 16 Jul 2024 11:59:33 -0500 Subject: [PATCH 42/57] merge --- doc/changelog.md | 1 + ex/high_throughput_inference/mli_driver.py | 50 +++++ ex/high_throughput_inference/mock_app.py | 195 ++++++++++++++++++ .../mock_app_redis.py | 88 ++++++++ ex/high_throughput_inference/redis_driver.py | 65 ++++++ .../standalone_workermanager.py | 96 +++++++++ smartsim/_core/entrypoints/service.py | 20 +- .../_core/launcher/dragon/dragonBackend.py | 27 ++- smartsim/_core/mli/comm/channel/channel.py | 7 +- .../_core/mli/comm/channel/dragonchannel.py | 22 +- smartsim/_core/mli/comm/channel/dragonfli.py | 69 +++++++ .../infrastructure/control/workermanager.py | 171 +++++++++++---- .../mli/infrastructure/environmentloader.py | 16 +- .../storage/dragonfeaturestore.py | 21 +- .../infrastructure/storage/featurestore.py | 5 +- .../mli/infrastructure/worker/torch_worker.py | 119 +++++++++++ .../_core/mli/infrastructure/worker/worker.py | 73 +++---- smartsim/_core/mli/message_handler.py | 10 +- tests/dragon/test_environment_loader.py | 7 +- tests/dragon/test_error_handling.py | 5 +- tests/dragon/utils/channel.py | 5 + tests/mli/test_torch_worker.py | 173 ++++++++++++++++ tests/mli/test_worker_manager.py | 3 +- tests/test_dragon_backend.py | 10 + 24 files changed, 1141 insertions(+), 117 deletions(-) create mode 100644 
ex/high_throughput_inference/mli_driver.py create mode 100644 ex/high_throughput_inference/mock_app.py create mode 100644 ex/high_throughput_inference/mock_app_redis.py create mode 100644 ex/high_throughput_inference/redis_driver.py create mode 100644 ex/high_throughput_inference/standalone_workermanager.py create mode 100644 smartsim/_core/mli/comm/channel/dragonfli.py create mode 100644 smartsim/_core/mli/infrastructure/worker/torch_worker.py create mode 100644 tests/mli/test_torch_worker.py diff --git a/doc/changelog.md b/doc/changelog.md index 83b3ce6b71..28de6e8f95 100644 --- a/doc/changelog.md +++ b/doc/changelog.md @@ -13,6 +13,7 @@ Jump to: Description +- Add TorchWorker first implementation and mock inference app example - Add error handling in Worker Manager pipeline - Add EnvironmentConfigLoader for ML Worker Manager - Add Model schema with model metadata included diff --git a/ex/high_throughput_inference/mli_driver.py b/ex/high_throughput_inference/mli_driver.py new file mode 100644 index 0000000000..6da559aa6f --- /dev/null +++ b/ex/high_throughput_inference/mli_driver.py @@ -0,0 +1,50 @@ + + +import os +import base64 +import cloudpickle +import sys +from smartsim import Experiment +from smartsim._core.mli.infrastructure.worker.torch_worker import TorchWorker +from smartsim.status import TERMINAL_STATUSES +import time +import typing as t + +device = "gpu" +filedir = os.path.dirname(__file__) +worker_manager_script_name = os.path.join(filedir, "standalone_workermanager.py") +app_script_name = os.path.join(filedir, "mock_app.py") +model_name = os.path.join(filedir, f"resnet50.{device.upper()}.pt") + +transport: t.Literal["hsta", "tcp"] = "hsta" + +os.environ["SMARTSIM_DRAGON_TRANSPORT"] = transport + +exp_path = os.path.join(filedir, f"MLI_proto_{transport.upper()}") +os.makedirs(exp_path, exist_ok=True) +exp = Experiment("MLI_proto", launcher="dragon", exp_path=exp_path) + +torch_worker_str = base64.b64encode(cloudpickle.dumps(TorchWorker)).decode("ascii") + +worker_manager_rs = exp.create_run_settings(sys.executable, [worker_manager_script_name, "--device", device, "--worker_class", torch_worker_str]) +worker_manager = exp.create_model("worker_manager", run_settings=worker_manager_rs) +worker_manager.attach_generator_files(to_copy=[worker_manager_script_name]) + +app_rs = exp.create_run_settings(sys.executable, exe_args = [app_script_name, "--device", device]) +app = exp.create_model("app", run_settings=app_rs) +app.attach_generator_files(to_copy=[app_script_name], to_symlink=[model_name]) + + +exp.generate(worker_manager, app, overwrite=True) +exp.start(worker_manager, app, block=False) + +while True: + if exp.get_status(app)[0] in TERMINAL_STATUSES: + exp.stop(worker_manager) + break + if exp.get_status(worker_manager)[0] in TERMINAL_STATUSES: + exp.stop(app) + break + time.sleep(5) + +print("Exiting.") \ No newline at end of file diff --git a/ex/high_throughput_inference/mock_app.py b/ex/high_throughput_inference/mock_app.py new file mode 100644 index 0000000000..45246db2e5 --- /dev/null +++ b/ex/high_throughput_inference/mock_app.py @@ -0,0 +1,195 @@ +# BSD 2-Clause License +# +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. 
Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +# isort: off +import dragon +from dragon import fli +from dragon.channels import Channel +import dragon.channels +from dragon.data.ddict.ddict import DDict +from dragon.globalservices.api_setup import connect_to_infrastructure +from dragon.utils import b64decode, b64encode + +# isort: on + +import argparse +import io +import numpy +import os +import time +import torch +import numbers + +from collections import OrderedDict +from smartsim._core.mli.message_handler import MessageHandler +from smartsim.log import get_logger + +logger = get_logger("App") + +class ProtoClient: + def __init__(self, timing_on: bool): + connect_to_infrastructure() + ddict_str = os.environ["SS_DRG_DDICT"] + self._ddict = DDict.attach(ddict_str) + to_worker_fli_str = None + while to_worker_fli_str is None: + try: + to_worker_fli_str = self._ddict["to_worker_fli"] + self._to_worker_fli = fli.FLInterface.attach(to_worker_fli_str) + except KeyError: + time.sleep(1) + self._from_worker_ch = Channel.make_process_local() + self._from_worker_ch_serialized = self._from_worker_ch.serialize() + self._to_worker_ch = Channel.make_process_local() + + self._start = None + self._interm = None + self._timings: OrderedDict[str, list[numbers.Number]] = OrderedDict() + self._timing_on = timing_on + + def _add_label_to_timings(self, label: str): + if label not in self._timings: + self._timings[label] = [] + + @staticmethod + def _format_number(number: numbers.Number): + return f"{number:0.4e}" + + def start_timings(self, batch_size: int): + if self._timing_on: + self._add_label_to_timings("batch_size") + self._timings["batch_size"].append(batch_size) + self._start = time.perf_counter() + self._interm = time.perf_counter() + + def end_timings(self): + if self._timing_on: + self._add_label_to_timings("total_time") + self._timings["total_time"].append(self._format_number(time.perf_counter()-self._start)) + + def measure_time(self, label: str): + if self._timing_on: + self._add_label_to_timings(label) + self._timings[label].append(self._format_number(time.perf_counter()-self._interm)) + self._interm = time.perf_counter() + + def print_timings(self, to_file: bool = False): + print(" ".join(self._timings.keys())) + value_array = numpy.array([value for value in self._timings.values()], dtype=float) + value_array = numpy.transpose(value_array) + for i in range(value_array.shape[0]): + print(" ".join(self._format_number(value) for value in value_array[i])) + if to_file: + numpy.save("timings.npy", value_array) + numpy.savetxt("timings.txt", value_array) + + + def 
run_model(self, model: bytes | str, batch: torch.Tensor): + self.start_timings(batch.shape[0]) + built_tensor = MessageHandler.build_tensor( + batch.numpy(), "c", "float32", list(batch.shape)) + self.measure_time("build_tensor") + built_model = None + if isinstance(model, str): + model_arg = MessageHandler.build_model_key(model) + else: + model_arg = MessageHandler.build_model(model, "resnet-50", "1.0") + request = MessageHandler.build_request( + reply_channel=self._from_worker_ch_serialized, + model= model_arg, + inputs=[built_tensor], + outputs=[], + output_descriptors=[], + custom_attributes=None, + ) + self.measure_time("build_request") + request_bytes = MessageHandler.serialize_request(request) + self.measure_time("serialize_request") + with self._to_worker_fli.sendh(timeout=None, stream_channel=self._to_worker_ch) as to_sendh: + to_sendh.send_bytes(request_bytes) + logger.info(f"Message size: {len(request_bytes)} bytes") + + self.measure_time("send") + with self._from_worker_ch.recvh(timeout=None) as from_recvh: + resp = from_recvh.recv_bytes(timeout=None) + self.measure_time("receive") + response = MessageHandler.deserialize_response(resp) + self.measure_time("deserialize_response") + result = torch.from_numpy( + numpy.frombuffer( + response.result.data[0].blob, + dtype=str(response.result.data[0].tensorDescriptor.dataType), + ) + ) + self.measure_time("deserialize_tensor") + + self.end_timings() + return result + + def set_model(self, key: str, model: bytes): + self._ddict[key] = model + + +class ResNetWrapper(): + def __init__(self, name: str, model: str): + self._model = torch.jit.load(model) + self._name = name + buffer = io.BytesIO() + scripted = torch.jit.trace(self._model, self.get_batch()) + torch.jit.save(scripted, buffer) + self._serialized_model = buffer.getvalue() + + def get_batch(self, batch_size: int=32): + return torch.randn((batch_size, 3, 224, 224), dtype=torch.float32) + + @property + def model(self): + return self._serialized_model + + @property + def name(self): + return self._name + +if __name__ == "__main__": + + parser = argparse.ArgumentParser("Mock application") + parser.add_argument("--device", default="cpu") + args = parser.parse_args() + + resnet = ResNetWrapper("resnet50", f"resnet50.{args.device.upper()}.pt") + + client = ProtoClient(timing_on=True) + client.set_model(resnet.name, resnet.model) + + total_iterations = 100 + + for batch_size in [1, 2, 4, 8, 16, 32, 64, 128]: + logger.info(f"Batch size: {batch_size}") + for iteration_number in range(total_iterations + int(batch_size==1)): + logger.info(f"Iteration: {iteration_number}") + client.run_model(resnet.name, resnet.get_batch(batch_size)) + + client.print_timings(to_file=True) \ No newline at end of file diff --git a/ex/high_throughput_inference/mock_app_redis.py b/ex/high_throughput_inference/mock_app_redis.py new file mode 100644 index 0000000000..c56b4fb8b4 --- /dev/null +++ b/ex/high_throughput_inference/mock_app_redis.py @@ -0,0 +1,88 @@ +# BSD 2-Clause License +# +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. 
Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import argparse +import io +import numpy +import time +import torch +from smartsim.log import get_logger +from smartredis import Client + +logger = get_logger("App") + +class ResNetWrapper(): + def __init__(self, name: str, model: str): + self._model = torch.jit.load(model) + self._name = name + buffer = io.BytesIO() + scripted = torch.jit.trace(self._model, self.get_batch()) + torch.jit.save(scripted, buffer) + self._serialized_model = buffer.getvalue() + + def get_batch(self, batch_size: int=32): + return torch.randn((batch_size, 3, 224, 224), dtype=torch.float32) + + @property + def model(self): + return self._serialized_model + + @property + def name(self): + return self._name + +if __name__ == "__main__": + + parser = argparse.ArgumentParser("Mock application") + parser.add_argument("--device", default="cpu") + args = parser.parse_args() + + resnet = ResNetWrapper("resnet50", f"resnet50.{args.device.upper()}.pt") + + client = Client(cluster=False, address=None) + client.set_model(resnet.name, resnet.model, backend='TORCH', device=args.device.upper()) + + total_iterations = 100 + timings=[] + for batch_size in [1, 2, 4, 8, 16, 32, 64, 128]: + logger.info(f"Batch size: {batch_size}") + for iteration_number in range(total_iterations + int(batch_size==1)): + timing = [batch_size] + logger.info(f"Iteration: {iteration_number}") + start = time.perf_counter() + client.put_tensor(name="batch", data=resnet.get_batch(batch_size).numpy()) + client.run_model(name=resnet.name, inputs=["batch"], outputs=["result"]) + result = client.get_tensor(name="result") + end = time.perf_counter() + timing.append(end-start) + timings.append(timing) + + + + timings_np = numpy.asarray(timings) + numpy.save("timings.npy", timings_np) + for timing in timings: + print(" ".join(str(t) for t in timing)) diff --git a/ex/high_throughput_inference/redis_driver.py b/ex/high_throughput_inference/redis_driver.py new file mode 100644 index 0000000000..ceddba4ef7 --- /dev/null +++ b/ex/high_throughput_inference/redis_driver.py @@ -0,0 +1,65 @@ +# BSD 2-Clause License +# +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. 
Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import os +import sys +from smartsim import Experiment +from smartsim.status import TERMINAL_STATUSES +import time +import typing as t + +device = "gpu" +filedir = os.path.dirname(__file__) +app_script_name = os.path.join(filedir, "mock_app_redis.py") +model_name = os.path.join(filedir, f"resnet50.{device.upper()}.pt") + + +exp_path = os.path.join(filedir, "redis_ai") +os.makedirs(exp_path, exist_ok=True) +exp = Experiment("redis_ai", launcher="slurm", exp_path=exp_path) + +db = exp.create_database(interface="hsn0") + +app_rs = exp.create_run_settings(sys.executable, exe_args = [app_script_name, "--device", device]) +app_rs.set_nodes(1) +app_rs.set_tasks(1) +app = exp.create_model("app", run_settings=app_rs) +app.attach_generator_files(to_copy=[app_script_name], to_symlink=[model_name]) + +exp.generate(db, app, overwrite=True) + +exp.start(db, app, block=False) + +while True: + if exp.get_status(app)[0] in TERMINAL_STATUSES: + exp.stop(db) + break + if exp.get_status(db)[0] in TERMINAL_STATUSES: + exp.stop(app) + break + time.sleep(5) + +print("Exiting.") \ No newline at end of file diff --git a/ex/high_throughput_inference/standalone_workermanager.py b/ex/high_throughput_inference/standalone_workermanager.py new file mode 100644 index 0000000000..c56e11a7c3 --- /dev/null +++ b/ex/high_throughput_inference/standalone_workermanager.py @@ -0,0 +1,96 @@ +# BSD 2-Clause License +# +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +# isort: off +import dragon +from dragon import fli +from dragon.channels import Channel +from dragon.data.ddict.ddict import DDict +from dragon.utils import b64decode, b64encode +from dragon.globalservices.api_setup import connect_to_infrastructure +# isort: on +import argparse +import base64 +import cloudpickle +import pickle +import os + +from smartsim._core.mli.comm.channel.dragonchannel import DragonCommChannel +from smartsim._core.mli.infrastructure.storage.dragonfeaturestore import DragonFeatureStore +from smartsim._core.mli.comm.channel.dragonfli import DragonFLIChannel +from smartsim._core.mli.infrastructure.worker.torch_worker import TorchWorker +from smartsim._core.mli.infrastructure.control.workermanager import WorkerManager +from smartsim._core.mli.infrastructure.environmentloader import EnvironmentConfigLoader + + +if __name__ == "__main__": + parser = argparse.ArgumentParser("Worker Manager") + parser.add_argument( + "--device", + type=str, + default="gpu", + choices="gpu cpu".split(), + help="Device on which the inference takes place", + ) + parser.add_argument( + "--worker_class", + type=str, + required=True, + help="Serialized class of worker to run", + ) + parser.add_argument( + "--num_workers", type=int, default=1, help="Number of workers to run" + ) + + args = parser.parse_args() + connect_to_infrastructure() + ddict_str = os.environ["SS_DRG_DDICT"] + ddict = DDict.attach(ddict_str) + + to_worker_channel = Channel.make_process_local() + to_worker_fli = fli.FLInterface(main_ch=to_worker_channel, manager_ch=None) + to_worker_fli_serialized = to_worker_fli.serialize() + ddict["to_worker_fli"] = to_worker_fli_serialized + + torch_worker = cloudpickle.loads(base64.b64decode(args.worker_class.encode('ascii')))() + + dfs = DragonFeatureStore(ddict) + comm_channel = DragonFLIChannel(to_worker_fli_serialized) + + os.environ["SSFeatureStore"] = base64.b64encode(pickle.dumps(dfs)).decode("utf-8") + os.environ["SSQueue"] = base64.b64encode(to_worker_fli_serialized).decode("utf-8") + + config_loader = EnvironmentConfigLoader() + + worker_manager = WorkerManager( + config_loader=config_loader, + worker=torch_worker, + as_service=True, + cooldown=10, + comm_channel_type=DragonCommChannel, + device = args.device, + ) + worker_manager.execute() diff --git a/smartsim/_core/entrypoints/service.py b/smartsim/_core/entrypoints/service.py index e03df6bea1..df9c2bbef6 100644 --- a/smartsim/_core/entrypoints/service.py +++ b/smartsim/_core/entrypoints/service.py @@ -46,7 +46,8 @@ def __init__( :param as_service: Determines if the host will run until shutdown criteria are met or as a run-once instance :param cooldown: Period of time to allow service to run before automatic - shutdown, in seconds. A non-zero, positive integer.""" + shutdown, in seconds. A non-zero, positive integer. 
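# Sketch for context (assumed usage, not part of this patch): the standalone worker
# manager above expects --worker_class to be a base64-encoded cloudpickle of the worker
# class, so a driver script would typically build that argument roughly like this:

import base64
import cloudpickle

from smartsim._core.mli.infrastructure.worker.torch_worker import TorchWorker

worker_class_arg = base64.b64encode(cloudpickle.dumps(TorchWorker)).decode("ascii")
# then launch: python standalone_workermanager.py --device gpu --worker_class <worker_class_arg>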
+ :param loop_delay: delay between iterations of the event loop""" self._as_service = as_service """If the service should run until shutdown function returns True""" self._cooldown = abs(cooldown) @@ -102,6 +103,23 @@ def execute(self) -> None: running = True cooldown_start: t.Optional[datetime.datetime] = None + headers = [ + "batch_size", + "w_deserialize", + "w_fetch_model", + "w_load_model", + "w_fetch_input", + "w_transform_input", + "w_execute", + "w_transform_output", + "w_assign_output", + "w_build_reply", + "w_serialize_resp", + "w_send", + ] + + print(",".join(headers)) + while running: self._on_iteration() diff --git a/smartsim/_core/launcher/dragon/dragonBackend.py b/smartsim/_core/launcher/dragon/dragonBackend.py index 2456606623..dcc5c8392b 100644 --- a/smartsim/_core/launcher/dragon/dragonBackend.py +++ b/smartsim/_core/launcher/dragon/dragonBackend.py @@ -36,8 +36,10 @@ # pylint: disable=import-error # isort: off +import dragon.data.ddict.ddict as dragon_ddict import dragon.infrastructure.connection as dragon_connection import dragon.infrastructure.policy as dragon_policy +import dragon.infrastructure.process_desc as dragon_process_desc import dragon.native.group_state as dragon_group_state import dragon.native.process as dragon_process import dragon.native.process_group as dragon_process_group @@ -187,6 +189,7 @@ def __init__(self, pid: int) -> None: self._view = DragonBackendView(self) logger.debug(self._view.host_desc) + self._infra_ddict: t.Optional[dragon_ddict.DDict] = None @property def hosts(self) -> list[str]: @@ -391,6 +394,22 @@ def _stop_steps(self) -> None: self._group_infos[step_id].status = SmartSimStatus.STATUS_CANCELLED self._group_infos[step_id].return_codes = [-9] + @property + def infra_ddict(self) -> str: + """Create a Dragon distributed dictionary and return its + serialized descriptor + """ + if self._infra_ddict is None: + logger.info("Creating DDict") + self._infra_ddict = dragon_ddict.DDict( + n_nodes=len(self._hosts), total_mem=len(self._hosts) * 1024**3 + ) # todo: parametrize + logger.info("Created DDict") + self._infra_ddict["creation"] = str(time.time()) + logger.info(self._infra_ddict["creation"]) + + return str(self._infra_ddict.serialize()) + def _start_steps(self) -> None: self._heartbeat() with self._queue_lock: @@ -406,6 +425,7 @@ def _start_steps(self) -> None: placement=dragon_policy.Policy.Placement.HOST_NAME, host_name=hosts[0], ) + options = dragon_process_desc.ProcessOptions(make_inf_channels=True) grp = dragon_process_group.ProcessGroup( restart=False, pmi_enabled=request.pmi_enabled, policy=global_policy ) @@ -421,10 +441,15 @@ def _start_steps(self) -> None: target=request.exe, args=request.exe_args, cwd=request.path, - env={**request.current_env, **request.env}, + env={ + **request.current_env, + **request.env, + "SS_DRG_DDICT": self.infra_ddict, + }, stdout=dragon_process.Popen.PIPE, stderr=dragon_process.Popen.PIPE, policy=local_policy, + options=options, ) grp.add_process(nproc=request.tasks_per_node, template=tmp_proc) diff --git a/smartsim/_core/mli/comm/channel/channel.py b/smartsim/_core/mli/comm/channel/channel.py index 201ab9deab..2318896a9b 100644 --- a/smartsim/_core/mli/comm/channel/channel.py +++ b/smartsim/_core/mli/comm/channel/channel.py @@ -41,9 +41,14 @@ def __init__(self, descriptor: t.Union[str, bytes]) -> None: @abstractmethod def send(self, value: bytes) -> None: - """Send a message throuh the underlying communication channel + """Send a message through the underlying communication channel :param value: 
The value to send""" + @abstractmethod + def recv(self) -> bytes: + """Receieve a message through the underlying communication channel + :returns: the received message""" + @property def descriptor(self) -> bytes: """Return the channel descriptor for the underlying dragon channel""" diff --git a/smartsim/_core/mli/comm/channel/dragonchannel.py b/smartsim/_core/mli/comm/channel/dragonchannel.py index 4fd26861ca..1409747a91 100644 --- a/smartsim/_core/mli/comm/channel/dragonchannel.py +++ b/smartsim/_core/mli/comm/channel/dragonchannel.py @@ -24,16 +24,18 @@ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -import typing as t +import sys import smartsim._core.mli.comm.channel.channel as cch from smartsim.log import get_logger logger = get_logger(__name__) -if t.TYPE_CHECKING: +try: import dragon.channels as dch - import dragon.utils as du +except ImportError as exc: + if not "pytest" in sys.modules: + raise exc from None class DragonCommChannel(cch.CommChannelBase): @@ -42,11 +44,17 @@ class DragonCommChannel(cch.CommChannelBase): def __init__(self, key: bytes) -> None: """Initialize the DragonCommChannel instance""" super().__init__(key) - # todo: do we need memory pool information to construct the channel correctly? - self._channel: "dch.Channel" = du.get_channel(key) + self._channel: dch.Channel = dch.Channel.attach(key) def send(self, value: bytes) -> None: """Send a message throuh the underlying communication channel :param value: The value to send""" - logger.debug(f"Channel {self.descriptor.decode('utf-8')} sending message") - self._channel.send_bytes(value) + with self._channel.sendh(timeout=None) as sendh: + sendh.send_bytes(value) + + def recv(self) -> bytes: + """Receieve a message through the underlying communication channel + :returns: the received message""" + with self._channel.recvh(timeout=None) as recvh: + message_bytes: bytes = recvh.recv_bytes(timeout=None) + return message_bytes diff --git a/smartsim/_core/mli/comm/channel/dragonfli.py b/smartsim/_core/mli/comm/channel/dragonfli.py new file mode 100644 index 0000000000..75f8fb4bfc --- /dev/null +++ b/smartsim/_core/mli/comm/channel/dragonfli.py @@ -0,0 +1,69 @@ +# BSD 2-Clause License +# +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +# isort: off +from dragon import fli +import dragon.channels as dch + +# isort: on + +import sys +import typing as t + +import smartsim._core.mli.comm.channel.channel as cch +from smartsim.log import get_logger + +logger = get_logger(__name__) + + +class DragonFLIChannel(cch.CommChannelBase): + """Passes messages by writing to a Dragon FLI Channel""" + + def __init__(self, fli_desc: bytes, sender_supplied: bool = True) -> None: + """Initialize the DragonFLIChannel instance""" + super().__init__(fli_desc) + # todo: do we need memory pool information to construct the channel correctly? + self._fli: "fli" = fli.FLInterface.attach(fli_desc) + self._channel: t.Optional["dch"] = ( + dch.Channel.make_process_local() if sender_supplied else None + ) + + def send(self, value: bytes) -> None: + """Send a message through the underlying communication channel + :param value: The value to send""" + with self._fli.sendh(timeout=None, stream_channel=self._channel) as sendh: + sendh.send_bytes(value) + + def recv(self) -> bytes: + """Receieve a message through the underlying communication channel + :returns: the received message""" + with self._fli.recvh(timeout=None) as recvh: + try: + request_bytes: bytes + request_bytes, _ = recvh.recv_bytes(timeout=None) + return request_bytes + except fli.FLIEOT as exc: + return b"" diff --git a/smartsim/_core/mli/infrastructure/control/workermanager.py b/smartsim/_core/mli/infrastructure/control/workermanager.py index fc87fb1431..c3950fa987 100644 --- a/smartsim/_core/mli/infrastructure/control/workermanager.py +++ b/smartsim/_core/mli/infrastructure/control/workermanager.py @@ -24,24 +24,34 @@ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
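# Sketch for context (a minimal illustration, assuming a Dragon environment and the
# SSQueue variable exported by the worker-manager entrypoint shown earlier): the new
# DragonFLIChannel wraps an attached FLI so callers only deal with send()/recv() on bytes.

import base64
import os

from smartsim._core.mli.comm.channel.dragonfli import DragonFLIChannel

queue = DragonFLIChannel(
    fli_desc=base64.b64decode(os.environ["SSQueue"]), sender_supplied=True
)
queue.send(b"<serialized capnp request>")  # client side: stream one request to the workers
payload = queue.recv()                     # worker side: returns b"" when the stream ends (FLIEOT)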
-import multiprocessing as mp +import sys + +# isort: off +import dragon +from dragon import fli + +# isort: on + +import time import typing as t import numpy as np -from smartsim._core.entrypoints.service import Service -from smartsim._core.mli.comm.channel.channel import CommChannelBase -from smartsim._core.mli.comm.channel.dragonchannel import DragonCommChannel -from smartsim._core.mli.infrastructure.environmentloader import EnvironmentConfigLoader -from smartsim._core.mli.infrastructure.storage.featurestore import FeatureStore -from smartsim._core.mli.infrastructure.worker.worker import ( +from .....error import SmartSimError +from .....log import get_logger +from ....entrypoints.service import Service +from ...comm.channel.channel import CommChannelBase +from ...comm.channel.dragonchannel import DragonCommChannel +from ...infrastructure.environmentloader import EnvironmentConfigLoader +from ...infrastructure.storage.featurestore import FeatureStore +from ...infrastructure.worker.worker import ( InferenceReply, InferenceRequest, + LoadModelResult, MachineLearningWorkerBase, ) -from smartsim._core.mli.message_handler import MessageHandler -from smartsim._core.mli.mli_schemas.response.response_capnp import Response -from smartsim.log import get_logger +from ...message_handler import MessageHandler +from ...mli_schemas.response.response_capnp import Response if t.TYPE_CHECKING: from dragon.fli import FLInterface @@ -53,7 +63,9 @@ def deserialize_message( - data_blob: bytes, channel_type: t.Type[CommChannelBase] + data_blob: bytes, + channel_type: t.Type[CommChannelBase], + device: t.Literal["cpu", "gpu"], ) -> InferenceRequest: """Deserialize a message from a byte stream into an InferenceRequest :param data_blob: The byte stream to deserialize""" @@ -88,12 +100,6 @@ def deserialize_message( ) output_keys: t.Optional[t.List[str]] = None - # # client example - # msg = Message() - # t = torch.Tensor() - # msg.inputs = [custom_byte_converter(t)] - # mli_client.request_inference(msg) - # # end client input_meta: t.List[t.Any] = [] if request.input.which() == "keys": @@ -203,6 +209,7 @@ def __init__( as_service: bool = False, cooldown: int = 0, comm_channel_type: t.Type[CommChannelBase] = DragonCommChannel, + device: t.Literal["cpu", "gpu"] = "cpu", ) -> None: """Initialize the WorkerManager :param config_loader: Environment config loader that loads the task queue and @@ -215,8 +222,7 @@ def __init__( """ super().__init__(as_service, cooldown) - """a collection of workers the manager is controlling""" - self._task_queue: t.Optional["FLInterface"] = config_loader.get_queue() + self._task_queue: t.Optional[CommChannelBase] = config_loader.get_queue() """the queue the manager monitors for new tasks""" self._feature_store: t.Optional[FeatureStore] = ( config_loader.get_feature_store() @@ -226,6 +232,10 @@ def __init__( """The ML Worker implementation""" self._comm_channel_type = comm_channel_type """The type of communication channel to construct for callbacks""" + self._device = device + """Device on which workers need to run""" + self._cached_models: dict[str, t.Any] = {} + """Dictionary of previously loaded models""" def _validate_request(self, request: InferenceRequest) -> bool: """Ensure the request can be processed. 
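# Sketch for context (assumed client side, condensed from the mock application that
# appears later in this series): the bytes consumed by deserialize_message() are built
# with MessageHandler; reply_channel_descriptor is a hypothetical placeholder for the
# caller's serialized response channel.

import torch

from smartsim._core.mli.message_handler import MessageHandler

reply_channel_descriptor = b"<serialized response channel>"  # assumed placeholder
batch = torch.randn((32, 3, 224, 224), dtype=torch.float32)
tensor = MessageHandler.build_tensor(batch.numpy(), "c", "float32", list(batch.shape))
request = MessageHandler.build_request(
    reply_channel=reply_channel_descriptor,
    model=MessageHandler.build_model_key("resnet50"),  # or build_model(...) for a raw model
    inputs=[tensor],
    outputs=[],
    output_descriptors=[],
    custom_attributes=None,
)
request_bytes = MessageHandler.serialize_request(request)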
@@ -267,42 +277,112 @@ def _on_iteration(self) -> None: logger.warning("No queue to check for tasks") return + timings = [] # timing # perform default deserialization of the message envelope - receiver = self._task_queue.recvh(use_main_as_stream_channel=True) - request_bytes, _ = receiver.recv_bytes() + request_bytes: bytes = self._task_queue.recv() - request = deserialize_message(request_bytes, self._comm_channel_type) + interm = time.perf_counter() # timing + request = deserialize_message( + request_bytes, self._comm_channel_type, self._device + ) if not self._validate_request(request): return + timings.append(time.perf_counter() - interm) # timing + interm = time.perf_counter() # timing + reply = InferenceReply() - try: - fetch_model_result = self._worker.fetch_model(request, self._feature_store) - except Exception as e: - exception_handler(e, request.callback, "fetching the model", reply) - return + if not request.raw_model: + if request.model_key is None: + response = build_failure_reply("fail", "Could not read model key.") + serialized_resp = MessageHandler.serialize_response(response) # type: ignore + if request.callback: + request.callback.send(serialized_resp) + return + if request.model_key in self._cached_models: + timings.append(time.perf_counter() - interm) # timing + interm = time.perf_counter() # timing + model_result = LoadModelResult(self._cached_models[request.model_key]) + + else: + fetch_model_result = None + while True: + try: + interm = time.perf_counter() # timing + fetch_model_result = self._worker.fetch_model( + request, self._feature_store + ) + except KeyError: + time.sleep(0.1) + except Exception as e: + exception_handler( + e, request.callback, "fetching the model", reply + ) + break + return + + if fetch_model_result is None: + response = build_failure_reply( + "fail", "Could not retrieve model from feature store." 
+ ) + serialized_resp = MessageHandler.serialize_response(response) # type: ignore + if request.callback: + request.callback.send(serialized_resp) + return + else: + timings.append(time.perf_counter() - interm) # timing + interm = time.perf_counter() # timing + try: + model_result = self._worker.load_model( + request, + fetch_result=fetch_model_result, + device=self._device, + ) + self._cached_models[request.model_key] = model_result.model + except Exception as e: + exception_handler( + e, request.callback, "loading the model", reply + ) + return - try: - model_result = self._worker.load_model(request, fetch_model_result) - except Exception as e: - exception_handler(e, request.callback, "loading the model", reply) - return + else: + try: + fetch_model_result = self._worker.fetch_model( + request, self._feature_store + ) + except Exception as e: + exception_handler(e, request.callback, "fetching the model", reply) + return + try: + model_result = self._worker.load_model( + request, fetch_result=fetch_model_result, device=self._device + ) + except Exception as e: + exception_handler(e, request.callback, "loading the model", reply) + return + timings.append(time.perf_counter() - interm) # timing + interm = time.perf_counter() # timing try: fetch_input_result = self._worker.fetch_inputs(request, self._feature_store) except Exception as e: exception_handler(e, request.callback, "fetching the inputs", reply) return + timings.append(time.perf_counter() - interm) # timing + interm = time.perf_counter() # timing try: transformed_input = self._worker.transform_input( - request, fetch_input_result + request, fetch_input_result, self._device ) except Exception as e: exception_handler(e, request.callback, "transforming the input", reply) return + timings.append(time.perf_counter() - interm) # timing + interm = time.perf_counter() # timing + try: execute_result = self._worker.execute( request, model_result, transformed_input @@ -311,12 +391,18 @@ def _on_iteration(self) -> None: exception_handler(e, request.callback, "executing", reply) return + timings.append(time.perf_counter() - interm) # timing + interm = time.perf_counter() # timing try: - transformed_output = self._worker.transform_output(request, execute_result) + transformed_output = self._worker.transform_output( + request, execute_result, self._device + ) except Exception as e: exception_handler(e, request.callback, "transforming the output", reply) return + timings.append(time.perf_counter() - interm) # timing + interm = time.perf_counter() # timing if request.output_keys: try: reply.output_keys = self._worker.place_output( @@ -330,18 +416,31 @@ def _on_iteration(self) -> None: else: reply.outputs = transformed_output.outputs - if reply.outputs is None or not reply.outputs: - response = build_failure_reply("fail", "Outputs not found.") + timings.append(time.perf_counter() - interm) # timing + interm = time.perf_counter() # timing + if reply.outputs is None or not reply.outputs: + response = build_failure_reply("fail", "no-results") else: reply.status_enum = "complete" reply.message = "Success" response = build_reply(reply) + timings.append(time.perf_counter() - interm) # timing + interm = time.perf_counter() # timing + serialized_resp = MessageHandler.serialize_response(response) # type: ignore + + timings.append(time.perf_counter() - interm) # timing + interm = time.perf_counter() # timing if request.callback: request.callback.send(serialized_resp) + timings.append(time.perf_counter() - interm) # timing + interm = time.perf_counter() # 
timing + + print(" ".join(str(time) for time in timings)) # timing + def _can_shutdown(self) -> bool: """Return true when the criteria to shut down the service are met.""" # todo: determine shutdown criteria diff --git a/smartsim/_core/mli/infrastructure/environmentloader.py b/smartsim/_core/mli/infrastructure/environmentloader.py index 267b668f63..9f6770623d 100644 --- a/smartsim/_core/mli/infrastructure/environmentloader.py +++ b/smartsim/_core/mli/infrastructure/environmentloader.py @@ -31,6 +31,7 @@ from dragon.fli import FLInterface # pylint: disable=all +from smartsim._core.mli.comm.channel.dragonfli import DragonFLIChannel from smartsim._core.mli.infrastructure.storage.featurestore import FeatureStore @@ -41,10 +42,12 @@ class EnvironmentConfigLoader: """ def __init__(self) -> None: - self._feature_store_descriptor = os.getenv("SSFeatureStore", None) - self._queue_descriptor = os.getenv("SSQueue", None) + self._feature_store_descriptor: t.Optional[str] = os.getenv( + "SSFeatureStore", None + ) + self._queue_descriptor: t.Optional[str] = os.getenv("SSQueue", None) self.feature_store: t.Optional[FeatureStore] = None - self.queue: t.Optional["FLInterface"] = None + self.queue: t.Optional[DragonFLIChannel] = None def get_feature_store(self) -> t.Optional[FeatureStore]: """Loads the Feature Store previously set in SSFeatureStore""" @@ -54,8 +57,11 @@ def get_feature_store(self) -> t.Optional[FeatureStore]: ) return self.feature_store - def get_queue(self) -> t.Optional["FLInterface"]: + def get_queue(self, sender_supplied: bool = True) -> t.Optional[DragonFLIChannel]: """Returns the Queue previously set in SSQueue""" if self._queue_descriptor is not None: - self.queue = FLInterface.attach(base64.b64decode(self._queue_descriptor)) + self.queue = DragonFLIChannel( + fli_desc=base64.b64decode(self._queue_descriptor), + sender_supplied=sender_supplied, + ) return self.queue diff --git a/smartsim/_core/mli/infrastructure/storage/dragonfeaturestore.py b/smartsim/_core/mli/infrastructure/storage/dragonfeaturestore.py index 8153255d0a..af592ed0ab 100644 --- a/smartsim/_core/mli/infrastructure/storage/dragonfeaturestore.py +++ b/smartsim/_core/mli/infrastructure/storage/dragonfeaturestore.py @@ -44,27 +44,28 @@ def __init__(self, storage: "DDict") -> None: """Initialize the DragonFeatureStore instance""" self._storage = storage - def __getitem__(self, key: str) -> t.Any: + def __getitem__(self, key: str) -> t.Union[str, bytes]: """Retrieve an item using key :param key: Unique key of an item to retrieve from the feature store""" - key_ = key.encode("utf-8") try: - return self._storage[key_] + value: t.Union[str, bytes] = self._storage[key] + return value + except KeyError as ex: + raise ex except Exception as ex: # note: explicitly avoid round-trip to check for key existence - raise sse.SmartSimError(f"{key} not found in feature store") from ex + raise sse.SmartSimError( + f"Could not get value for existing key {key}, error:\n{ex}" + ) from ex - def __setitem__(self, key: str, value: bytes) -> None: + def __setitem__(self, key: str, value: t.Union[str, bytes]) -> None: """Assign a value using key :param key: Unique key of an item to set in the feature store :param value: Value to persist in the feature store""" - key_ = key.encode("utf-8") - self._storage[key_] = value + self._storage[key] = value - def __contains__(self, key: t.Union[str, bytes]) -> bool: + def __contains__(self, key: str) -> bool: """Membership operator to test for a key existing within the feature store. 
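# Sketch for context (assumes the launcher exported SSFeatureStore and SSQueue, as the
# worker-manager entrypoint above does): EnvironmentConfigLoader turns those environment
# variables back into usable objects for the WorkerManager.

from smartsim._core.mli.infrastructure.environmentloader import EnvironmentConfigLoader

config = EnvironmentConfigLoader()
feature_store = config.get_feature_store()  # unpickled feature store, or None if unset
task_queue = config.get_queue()             # DragonFLIChannel attached to the request queue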
Return `True` if the key is found, `False` otherwise :param key: Unique key of an item to retrieve from the feature store""" - if isinstance(key, str): - key = key.encode("utf-8") return key in self._storage diff --git a/smartsim/_core/mli/infrastructure/storage/featurestore.py b/smartsim/_core/mli/infrastructure/storage/featurestore.py index ec4086b732..553e13b10f 100644 --- a/smartsim/_core/mli/infrastructure/storage/featurestore.py +++ b/smartsim/_core/mli/infrastructure/storage/featurestore.py @@ -24,6 +24,7 @@ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +import typing as t from abc import ABC, abstractmethod @@ -32,12 +33,12 @@ class FeatureStore(ABC): values from a feature store implementation""" @abstractmethod - def __getitem__(self, key: str) -> bytes: + def __getitem__(self, key: str) -> t.Union[str, bytes]: """Retrieve an item using key :param key: Unique key of an item to retrieve from the feature store""" @abstractmethod - def __setitem__(self, key: str, value: bytes) -> None: + def __setitem__(self, key: str, value: t.Union[str, bytes]) -> None: """Assign a value using key :param key: Unique key of an item to set in the feature store :param value: Value to persist in the feature store""" diff --git a/smartsim/_core/mli/infrastructure/worker/torch_worker.py b/smartsim/_core/mli/infrastructure/worker/torch_worker.py new file mode 100644 index 0000000000..a4e725ab99 --- /dev/null +++ b/smartsim/_core/mli/infrastructure/worker/torch_worker.py @@ -0,0 +1,119 @@ +# BSD 2-Clause License +# +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
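# Sketch for context (a toy example, not part of the patch): any mapping that satisfies
# the abstract FeatureStore contract above can back the worker pipeline; the tests use a
# file-system variant, and a minimal in-memory stand-in might look like this:

import typing as t

from smartsim._core.mli.infrastructure.storage.featurestore import FeatureStore


class InMemoryFeatureStore(FeatureStore):
    """Dict-backed stand-in for local experimentation."""

    def __init__(self) -> None:
        self._storage: t.Dict[str, t.Union[str, bytes]] = {}

    def __getitem__(self, key: str) -> t.Union[str, bytes]:
        return self._storage[key]

    def __setitem__(self, key: str, value: t.Union[str, bytes]) -> None:
        self._storage[key] = value

    def __contains__(self, key: str) -> bool:
        return key in self._storage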
+ +import io + +import numpy as np +import torch + +from .....error import SmartSimError +from .....log import get_logger +from ...mli_schemas.tensor import tensor_capnp +from .worker import ( + ExecuteResult, + FetchInputResult, + FetchModelResult, + InferenceRequest, + LoadModelResult, + MachineLearningWorkerBase, + TransformInputResult, + TransformOutputResult, +) + +logger = get_logger(__name__) + + +class TorchWorker(MachineLearningWorkerBase): + """A worker that executes a PyTorch model.""" + + @staticmethod + def load_model( + request: InferenceRequest, fetch_result: FetchModelResult, device: str + ) -> LoadModelResult: + if fetch_result.model_bytes: + model_bytes = fetch_result.model_bytes + elif request.raw_model and request.raw_model.data: + model_bytes = request.raw_model.data + else: + raise ValueError("Unable to load model without reference object") + + device_to_torch = {"cpu": "cpu", "gpu": "cuda"} + device = device_to_torch[device] + buffer = io.BytesIO(initial_bytes=model_bytes) + model = torch.jit.load(buffer, map_location=device) # type: ignore + result = LoadModelResult(model) + return result + + @staticmethod + def transform_input( + request: InferenceRequest, fetch_result: FetchInputResult, device: str + ) -> TransformInputResult: + result = [] + + device_to_torch = {"cpu": "cpu", "gpu": "cuda"} + device = device_to_torch[device] + if fetch_result.meta is None: + raise ValueError("Cannot reconstruct tensor without meta information") + for item, item_meta in zip(fetch_result.inputs, fetch_result.meta): + tensor_desc: tensor_capnp.TensorDescriptor = item_meta + result.append( + torch.tensor(np.frombuffer(item, dtype=str(tensor_desc.dataType))) + .to(device) + .reshape(tuple(dim for dim in tensor_desc.dimensions)) + ) + return TransformInputResult(result) + # return data # note: this fails copy test! + + @staticmethod + def execute( + request: InferenceRequest, + load_result: LoadModelResult, + transform_result: TransformInputResult, + ) -> ExecuteResult: + if not load_result.model: + raise SmartSimError("Model must be loaded to execute") + + model: torch.nn.Module = load_result.model + model.eval() + results = [model(tensor).detach() for tensor in transform_result.transformed] + + execute_result = ExecuteResult(results) + return execute_result + + @staticmethod + def transform_output( + request: InferenceRequest, + execute_result: ExecuteResult, + result_device: str, + ) -> TransformOutputResult: + if result_device != "cpu": + transformed = [item.to("cpu") for item in execute_result.predictions] + # todo: need the shape from latest schemas added here. 
+ return TransformOutputResult(transformed, None, "c", "float32") # fixme + + return TransformOutputResult( + execute_result.predictions, None, "c", "float32" + ) # fixme diff --git a/smartsim/_core/mli/infrastructure/worker/worker.py b/smartsim/_core/mli/infrastructure/worker/worker.py index 77781f2286..3d3b36fbdf 100644 --- a/smartsim/_core/mli/infrastructure/worker/worker.py +++ b/smartsim/_core/mli/infrastructure/worker/worker.py @@ -27,11 +27,11 @@ import typing as t from abc import ABC, abstractmethod -import smartsim.error as sse -from smartsim._core.mli.comm.channel.channel import CommChannelBase -from smartsim._core.mli.infrastructure.storage.featurestore import FeatureStore -from smartsim._core.mli.mli_schemas.model.model_capnp import Model -from smartsim.log import get_logger +from .....error import SmartSimError +from .....log import get_logger +from ...comm.channel.channel import CommChannelBase +from ...infrastructure.storage.featurestore import FeatureStore +from ...mli_schemas.model.model_capnp import Model if t.TYPE_CHECKING: from smartsim._core.mli.mli_schemas.response.response_capnp import StatusEnum @@ -110,23 +110,23 @@ def __init__(self, result: t.Any) -> None: class FetchInputResult: """A wrapper around fetched inputs""" - def __init__(self, result: t.List[bytes]) -> None: + def __init__(self, result: t.List[bytes], meta: t.Optional[t.List[t.Any]]) -> None: """Initialize the object""" self.inputs = result + self.meta = meta class TransformOutputResult: """A wrapper around inference results transformed for transmission""" def __init__( - self, result: t.Any, shape: t.List[int], order: str, dtype: str + self, result: t.Any, shape: t.Optional[t.List[int]], order: str, dtype: str ) -> None: """Initialize the OutputTransformResult""" self.outputs = result self.shape = shape self.order = order self.dtype = dtype - # todo: determine if each output must have an individual (shape, order, dtype) class CreateInputBatchResult: @@ -142,7 +142,7 @@ class FetchModelResult: def __init__(self, result: bytes) -> None: """Initialize the object""" - self.model_bytes = result + self.model_bytes: bytes = result class MachineLearningWorkerCore: @@ -156,8 +156,6 @@ def fetch_model( :param request: The request that triggered the pipeline :param feature_store: The feature store used for persistence :return: Raw bytes of the model""" - if not feature_store: - raise ValueError("Feature store is required for model retrieval") if request.raw_model: # Should we cache model in the feature store? 
@@ -166,17 +164,20 @@ def fetch_model( # short-circuit and return the directly supplied model return FetchModelResult(request.raw_model.data) + if not feature_store: + raise ValueError("Feature store is required for model retrieval") + if not request.model_key: - raise sse.SmartSimError( + raise SmartSimError( "Key must be provided to retrieve model from feature store" ) try: - raw_bytes = feature_store[request.model_key] + raw_bytes: bytes = t.cast(bytes, feature_store[request.model_key]) return FetchModelResult(raw_bytes) except FileNotFoundError as ex: logger.exception(ex) - raise sse.SmartSimError( + raise SmartSimError( f"Model could not be retrieved with key {request.model_key}" ) from ex @@ -189,24 +190,27 @@ def fetch_inputs( :param request: The request that triggered the pipeline :param feature_store: The feature store used for persistence :return: the fetched input""" + + if request.raw_inputs: + return FetchInputResult(request.raw_inputs, request.input_meta) + if not feature_store: - raise ValueError("Feature store is required for input retrieval") + raise ValueError("No input and no feature store provided") if request.input_keys: data: t.List[bytes] = [] for input_ in request.input_keys: try: - tensor_bytes = feature_store[input_] + tensor_bytes = t.cast(bytes, feature_store[input_]) data.append(tensor_bytes) except KeyError as ex: logger.exception(ex) - raise sse.SmartSimError( + raise SmartSimError( f"Model could not be retrieved with key {input_}" ) from ex - return FetchInputResult(data) - - if request.raw_inputs: - return FetchInputResult(request.raw_inputs) + return FetchInputResult( + data, None + ) # fixme: need to get both tensor and descriptor raise ValueError("No input source") @@ -254,32 +258,26 @@ class MachineLearningWorkerBase(MachineLearningWorkerCore, ABC): """Abstrct base class providing contract for a machine learning worker implementation.""" - # @staticmethod - # @abstractmethod - # def deserialize(request: InferenceRequest) -> InferenceRequest: - # """Given a collection of data serialized to bytes, convert the bytes - # to a proper representation used by the ML backend - # :param data_blob: inference request as a byte-serialized blob - # :return: InferenceRequest deserialized from the input""" - @staticmethod @abstractmethod def load_model( - request: InferenceRequest, fetch_result: FetchModelResult + request: InferenceRequest, fetch_result: FetchModelResult, device: str ) -> LoadModelResult: """Given a loaded MachineLearningModel, ensure it is loaded into device memory :param request: The request that triggered the pipeline + :param device: The device on which the model must be placed :return: ModelLoadResult wrapping the model loaded for the request""" @staticmethod @abstractmethod def transform_input( - request: InferenceRequest, fetch_result: FetchInputResult + request: InferenceRequest, fetch_result: FetchInputResult, device: str ) -> TransformInputResult: """Given a collection of data, perform a transformation on the data :param request: The request that triggered the pipeline :param fetch_result: Raw output from fetching inputs out of a feature store + :param device: The device on which the transformed input must be placed :return: The transformed inputs wrapped in a InputTransformResult""" @staticmethod @@ -298,20 +296,11 @@ def execute( @staticmethod @abstractmethod def transform_output( - request: InferenceRequest, - execute_result: ExecuteResult, + request: InferenceRequest, execute_result: ExecuteResult, result_device: str ) -> 
TransformOutputResult: """Given inference results, perform transformations required to transmit results to the requestor. :param request: The request that triggered the pipeline :param execute_result: The result of inference wrapped in an ExecuteResult + :param result_device: The device on which the result of inference is placed :return:""" - - # @staticmethod - # @abstractmethod - # def serialize_reply( - # request: InferenceRequest, results: OutputTransformResult - # ) -> bytes: - # """Given an output, serialize to bytes for transport - # :param reply: The result of the inference pipeline - # :return: a byte-serialized version of the reply""" diff --git a/smartsim/_core/mli/message_handler.py b/smartsim/_core/mli/message_handler.py index 16cb242b7c..bcf1cfdf14 100644 --- a/smartsim/_core/mli/message_handler.py +++ b/smartsim/_core/mli/message_handler.py @@ -396,7 +396,9 @@ def deserialize_request(request_bytes: bytes) -> request_capnp.Request: :param request_bytes: Bytes to be deserialized into a Request """ - bytes_message = request_capnp.Request.from_bytes(request_bytes) + bytes_message = request_capnp.Request.from_bytes( + request_bytes, traversal_limit_in_words=2**63 + ) with bytes_message as message: return message @@ -489,7 +491,7 @@ def _assign_custom_response_attributes( response.customAttributes.tf = custom_attrs # type: ignore else: raise ValueError("""Invalid custom attribute class name. - Expected 'TensorFlowResponseAttributes' or + Expected 'TensorFlowResponseAttributes' or 'TorchResponseAttributes'.""") except Exception as e: raise ValueError("Error assigning custom attributes to response.") from e @@ -534,7 +536,9 @@ def deserialize_response(response_bytes: bytes) -> response_capnp.Response: """ Deserializes a serialized response message. 
""" - bytes_message = response_capnp.Response.from_bytes(response_bytes) + bytes_message = response_capnp.Response.from_bytes( + response_bytes, traversal_limit_in_words=2**63 + ) with bytes_message as message: return message diff --git a/tests/dragon/test_environment_loader.py b/tests/dragon/test_environment_loader.py index d339fec885..00db0a9d32 100644 --- a/tests/dragon/test_environment_loader.py +++ b/tests/dragon/test_environment_loader.py @@ -64,10 +64,9 @@ def test_environment_loader_attach_FLI(content, monkeypatch): config = EnvironmentConfigLoader() config_queue = config.get_queue() - new_sender = config_queue.sendh(use_main_as_stream_channel=True) - new_sender.send_bytes(content) + new_sender = config_queue.send(content) - old_recv = queue.recvh(use_main_as_stream_channel=True) + old_recv = queue.recvh() result, _ = old_recv.recv_bytes() assert result == content @@ -81,7 +80,7 @@ def test_environment_loader_serialize_FLI(monkeypatch): config = EnvironmentConfigLoader() config_queue = config.get_queue() - assert config_queue.serialize() == queue.serialize() + assert config_queue._fli.serialize() == queue.serialize() def test_environment_loader_FLI_fails(monkeypatch): diff --git a/tests/dragon/test_error_handling.py b/tests/dragon/test_error_handling.py index 4ced3e0ff2..2aa81e666f 100644 --- a/tests/dragon/test_error_handling.py +++ b/tests/dragon/test_error_handling.py @@ -89,8 +89,7 @@ def setup_worker_manager(test_dir, monkeypatch: pytest.MonkeyPatch): test_dir, model, [tensor_key], [tensor_key], [], None ) ser_request = MessageHandler.serialize_request(request) - new_sender = worker_manager._task_queue.sendh(use_main_as_stream_channel=True) - new_sender.send_bytes(ser_request) + new_sender = worker_manager._task_queue.send(ser_request) return worker_manager, integrated_worker @@ -163,7 +162,7 @@ def test_pipeline_stage_errors_handled( monkeypatch.setattr( integrated_worker, "fetch_inputs", - MagicMock(return_value=FetchInputResult([b"result_bytes"])), + MagicMock(return_value=FetchInputResult([b"result_bytes"], None)), ) if stage not in ["fetch_model", "load_model", "fetch_inputs", "transform_input"]: monkeypatch.setattr( diff --git a/tests/dragon/utils/channel.py b/tests/dragon/utils/channel.py index 4bc2014ea3..2523b1a6bf 100644 --- a/tests/dragon/utils/channel.py +++ b/tests/dragon/utils/channel.py @@ -57,3 +57,8 @@ def send(self, value: bytes) -> None: f"Channel {self.descriptor.decode('utf-8')} sending message to {self._file_path}" ) self._file_path.write_bytes(value) + + def recv(self) -> bytes: + """Receive a message through the underlying communication channel + :returns: the received message""" + ... diff --git a/tests/mli/test_torch_worker.py b/tests/mli/test_torch_worker.py new file mode 100644 index 0000000000..0b1cd4ccf3 --- /dev/null +++ b/tests/mli/test_torch_worker.py @@ -0,0 +1,173 @@ +# BSD 2-Clause License +# +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. 
+# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import io + +import numpy as np +import pytest +import torch +from torch import nn +from torch.nn import functional as F + +from smartsim._core.mli.infrastructure.worker.torch_worker import TorchWorker +from smartsim._core.mli.infrastructure.worker.worker import ( + ExecuteResult, + FetchInputResult, + FetchModelResult, + InferenceRequest, + LoadModelResult, + TransformInputResult, +) +from smartsim._core.mli.message_handler import MessageHandler +from smartsim.log import get_logger + +logger = get_logger(__name__) +# The tests in this file belong to the group_a group +pytestmark = pytest.mark.group_a + + +# simple MNIST in PyTorch +class Net(nn.Module): + def __init__(self): + super(Net, self).__init__() + self.conv1 = nn.Conv2d(1, 32, 3, 1) + self.conv2 = nn.Conv2d(32, 64, 3, 1) + self.dropout1 = nn.Dropout(0.25) + self.dropout2 = nn.Dropout(0.5) + self.fc1 = nn.Linear(9216, 128) + self.fc2 = nn.Linear(128, 10) + + def forward(self, x): + x = self.conv1(x) + x = F.relu(x) + x = self.conv2(x) + x = F.relu(x) + x = F.max_pool2d(x, 2) + x = self.dropout1(x) + x = torch.flatten(x, 1) + x = self.fc1(x) + x = F.relu(x) + x = self.dropout2(x) + x = self.fc2(x) + output = F.log_softmax(x, dim=1) + return output + + +torch_device = {"cpu": "cpu", "gpu": "cuda"} + + +def get_batch() -> torch.Tensor: + return torch.rand(20, 1, 28, 28) + + +def create_torch_model(): + n = Net() + example_forward_input = get_batch() + module = torch.jit.trace(n, example_forward_input) + model_buffer = io.BytesIO() + torch.jit.save(module, model_buffer) + return model_buffer.getvalue() + + +def get_request() -> InferenceRequest: + + tensors = [get_batch() for _ in range(2)] + serialized_tensors = [ + MessageHandler.build_tensor(tensor.numpy(), "c", "float32", list(tensor.shape)) + for tensor in tensors + ] + + return InferenceRequest( + model_key="model", + callback=None, + raw_inputs=[s_tensor.blob for s_tensor in serialized_tensors], + input_keys=None, + input_meta=[s_tensor.tensorDescriptor for s_tensor in serialized_tensors], + output_keys=None, + raw_model=create_torch_model(), + batch_size=0, + ) + + +sample_request: InferenceRequest = get_request() +worker = TorchWorker() + + +def test_load_model(mlutils) -> None: + fetch_model_result = FetchModelResult(sample_request.raw_model) + load_model_result = worker.load_model( + sample_request, fetch_model_result, mlutils.get_test_device().lower() + ) + + assert load_model_result.model( + get_batch().to(torch_device[mlutils.get_test_device().lower()]) + ).shape == torch.Size((20, 10)) + + +def test_transform_input(mlutils) -> None: + fetch_input_result = FetchInputResult( + sample_request.raw_inputs, sample_request.input_meta + ) + + transform_input_result = worker.transform_input( + 
sample_request, fetch_input_result, mlutils.get_test_device().lower() + ) + + assert all( + transformed.shape == get_batch().shape + for transformed in transform_input_result.transformed + ) + + +def test_execute(mlutils) -> None: + load_model_result = LoadModelResult( + Net().to(torch_device[mlutils.get_test_device().lower()]) + ) + transform_result = TransformInputResult( + [ + get_batch().to(torch_device[mlutils.get_test_device().lower()]) + for _ in range(2) + ] + ) + + execute_result = worker.execute(sample_request, load_model_result, transform_result) + + assert all( + result.shape == torch.Size((20, 10)) for result in execute_result.predictions + ) + + +def test_transform_output(mlutils): + execute_result = ExecuteResult([torch.rand((20, 10)) for _ in range(2)]) + + transformed_output = worker.transform_output( + sample_request, execute_result, torch_device[mlutils.get_test_device().lower()] + ) + + assert transformed_output.outputs == execute_result.predictions + assert transformed_output.shape == None + assert transformed_output.order == "c" + assert transformed_output.dtype == "float32" diff --git a/tests/mli/test_worker_manager.py b/tests/mli/test_worker_manager.py index 7fd2c9aad4..df4b0a637f 100644 --- a/tests/mli/test_worker_manager.py +++ b/tests/mli/test_worker_manager.py @@ -29,11 +29,10 @@ import multiprocessing as mp import pathlib import time -import typing as t import pytest -import torch +torch = pytest.importorskip("torch") dragon = pytest.importorskip("dragon") from smartsim._core.mli.infrastructure.control.workermanager import ( diff --git a/tests/test_dragon_backend.py b/tests/test_dragon_backend.py index a510f660a5..f284f38d99 100644 --- a/tests/test_dragon_backend.py +++ b/tests/test_dragon_backend.py @@ -103,6 +103,16 @@ def get_mock_backend(monkeypatch: pytest.MonkeyPatch) -> "DragonBackend": "dragon.infrastructure.connection", MagicMock(), ) + monkeypatch.setitem( + sys.modules, + "dragon.infrastructure.process_desc", + MagicMock(), + ) + monkeypatch.setitem( + sys.modules, + "dragon.data.ddict.ddict", + MagicMock(), + ) monkeypatch.setitem( sys.modules, "dragon.infrastructure.policy", From 7451b2a55d670cf29090ff2afbe7fe29833985e7 Mon Sep 17 00:00:00 2001 From: Alyssa Cote Date: Tue, 16 Jul 2024 12:03:59 -0500 Subject: [PATCH 43/57] Revert "merge" This reverts commit 0af26a1c83069734a3211a5f05a5bfaf2e14d7d6. 
revert --- doc/changelog.md | 1 - ex/high_throughput_inference/mli_driver.py | 50 ----- ex/high_throughput_inference/mock_app.py | 195 ------------------ .../mock_app_redis.py | 88 -------- ex/high_throughput_inference/redis_driver.py | 65 ------ .../standalone_workermanager.py | 96 --------- smartsim/_core/entrypoints/service.py | 20 +- .../_core/launcher/dragon/dragonBackend.py | 27 +-- smartsim/_core/mli/comm/channel/channel.py | 7 +- .../_core/mli/comm/channel/dragonchannel.py | 22 +- smartsim/_core/mli/comm/channel/dragonfli.py | 69 ------- .../infrastructure/control/workermanager.py | 171 ++++----------- .../mli/infrastructure/environmentloader.py | 16 +- .../storage/dragonfeaturestore.py | 21 +- .../infrastructure/storage/featurestore.py | 5 +- .../mli/infrastructure/worker/torch_worker.py | 119 ----------- .../_core/mli/infrastructure/worker/worker.py | 73 ++++--- smartsim/_core/mli/message_handler.py | 10 +- tests/dragon/test_environment_loader.py | 7 +- tests/dragon/test_error_handling.py | 5 +- tests/dragon/utils/channel.py | 5 - tests/mli/test_torch_worker.py | 173 ---------------- tests/mli/test_worker_manager.py | 3 +- tests/test_dragon_backend.py | 10 - 24 files changed, 117 insertions(+), 1141 deletions(-) delete mode 100644 ex/high_throughput_inference/mli_driver.py delete mode 100644 ex/high_throughput_inference/mock_app.py delete mode 100644 ex/high_throughput_inference/mock_app_redis.py delete mode 100644 ex/high_throughput_inference/redis_driver.py delete mode 100644 ex/high_throughput_inference/standalone_workermanager.py delete mode 100644 smartsim/_core/mli/comm/channel/dragonfli.py delete mode 100644 smartsim/_core/mli/infrastructure/worker/torch_worker.py delete mode 100644 tests/mli/test_torch_worker.py diff --git a/doc/changelog.md b/doc/changelog.md index 28de6e8f95..83b3ce6b71 100644 --- a/doc/changelog.md +++ b/doc/changelog.md @@ -13,7 +13,6 @@ Jump to: Description -- Add TorchWorker first implementation and mock inference app example - Add error handling in Worker Manager pipeline - Add EnvironmentConfigLoader for ML Worker Manager - Add Model schema with model metadata included diff --git a/ex/high_throughput_inference/mli_driver.py b/ex/high_throughput_inference/mli_driver.py deleted file mode 100644 index 6da559aa6f..0000000000 --- a/ex/high_throughput_inference/mli_driver.py +++ /dev/null @@ -1,50 +0,0 @@ - - -import os -import base64 -import cloudpickle -import sys -from smartsim import Experiment -from smartsim._core.mli.infrastructure.worker.torch_worker import TorchWorker -from smartsim.status import TERMINAL_STATUSES -import time -import typing as t - -device = "gpu" -filedir = os.path.dirname(__file__) -worker_manager_script_name = os.path.join(filedir, "standalone_workermanager.py") -app_script_name = os.path.join(filedir, "mock_app.py") -model_name = os.path.join(filedir, f"resnet50.{device.upper()}.pt") - -transport: t.Literal["hsta", "tcp"] = "hsta" - -os.environ["SMARTSIM_DRAGON_TRANSPORT"] = transport - -exp_path = os.path.join(filedir, f"MLI_proto_{transport.upper()}") -os.makedirs(exp_path, exist_ok=True) -exp = Experiment("MLI_proto", launcher="dragon", exp_path=exp_path) - -torch_worker_str = base64.b64encode(cloudpickle.dumps(TorchWorker)).decode("ascii") - -worker_manager_rs = exp.create_run_settings(sys.executable, [worker_manager_script_name, "--device", device, "--worker_class", torch_worker_str]) -worker_manager = exp.create_model("worker_manager", run_settings=worker_manager_rs) 
-worker_manager.attach_generator_files(to_copy=[worker_manager_script_name]) - -app_rs = exp.create_run_settings(sys.executable, exe_args = [app_script_name, "--device", device]) -app = exp.create_model("app", run_settings=app_rs) -app.attach_generator_files(to_copy=[app_script_name], to_symlink=[model_name]) - - -exp.generate(worker_manager, app, overwrite=True) -exp.start(worker_manager, app, block=False) - -while True: - if exp.get_status(app)[0] in TERMINAL_STATUSES: - exp.stop(worker_manager) - break - if exp.get_status(worker_manager)[0] in TERMINAL_STATUSES: - exp.stop(app) - break - time.sleep(5) - -print("Exiting.") \ No newline at end of file diff --git a/ex/high_throughput_inference/mock_app.py b/ex/high_throughput_inference/mock_app.py deleted file mode 100644 index 45246db2e5..0000000000 --- a/ex/high_throughput_inference/mock_app.py +++ /dev/null @@ -1,195 +0,0 @@ -# BSD 2-Clause License -# -# Copyright (c) 2021-2024, Hewlett Packard Enterprise -# All rights reserved. -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are met: -# -# 1. Redistributions of source code must retain the above copyright notice, this -# list of conditions and the following disclaimer. -# -# 2. Redistributions in binary form must reproduce the above copyright notice, -# this list of conditions and the following disclaimer in the documentation -# and/or other materials provided with the distribution. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
-
-# isort: off
-import dragon
-from dragon import fli
-from dragon.channels import Channel
-import dragon.channels
-from dragon.data.ddict.ddict import DDict
-from dragon.globalservices.api_setup import connect_to_infrastructure
-from dragon.utils import b64decode, b64encode
-
-# isort: on
-
-import argparse
-import io
-import numpy
-import os
-import time
-import torch
-import numbers
-
-from collections import OrderedDict
-from smartsim._core.mli.message_handler import MessageHandler
-from smartsim.log import get_logger
-
-logger = get_logger("App")
-
-class ProtoClient:
-    def __init__(self, timing_on: bool):
-        connect_to_infrastructure()
-        ddict_str = os.environ["SS_DRG_DDICT"]
-        self._ddict = DDict.attach(ddict_str)
-        to_worker_fli_str = None
-        while to_worker_fli_str is None:
-            try:
-                to_worker_fli_str = self._ddict["to_worker_fli"]
-                self._to_worker_fli = fli.FLInterface.attach(to_worker_fli_str)
-            except KeyError:
-                time.sleep(1)
-        self._from_worker_ch = Channel.make_process_local()
-        self._from_worker_ch_serialized = self._from_worker_ch.serialize()
-        self._to_worker_ch = Channel.make_process_local()
-
-        self._start = None
-        self._interm = None
-        self._timings: OrderedDict[str, list[numbers.Number]] = OrderedDict()
-        self._timing_on = timing_on
-
-    def _add_label_to_timings(self, label: str):
-        if label not in self._timings:
-            self._timings[label] = []
-
-    @staticmethod
-    def _format_number(number: numbers.Number):
-        return f"{number:0.4e}"
-
-    def start_timings(self, batch_size: int):
-        if self._timing_on:
-            self._add_label_to_timings("batch_size")
-            self._timings["batch_size"].append(batch_size)
-            self._start = time.perf_counter()
-            self._interm = time.perf_counter()
-
-    def end_timings(self):
-        if self._timing_on:
-            self._add_label_to_timings("total_time")
-            self._timings["total_time"].append(self._format_number(time.perf_counter()-self._start))
-
-    def measure_time(self, label: str):
-        if self._timing_on:
-            self._add_label_to_timings(label)
-            self._timings[label].append(self._format_number(time.perf_counter()-self._interm))
-            self._interm = time.perf_counter()
-
-    def print_timings(self, to_file: bool = False):
-        print(" ".join(self._timings.keys()))
-        value_array = numpy.array([value for value in self._timings.values()], dtype=float)
-        value_array = numpy.transpose(value_array)
-        for i in range(value_array.shape[0]):
-            print(" ".join(self._format_number(value) for value in value_array[i]))
-        if to_file:
-            numpy.save("timings.npy", value_array)
-            numpy.savetxt("timings.txt", value_array)
-
-
-    def run_model(self, model: bytes | str, batch: torch.Tensor):
-        self.start_timings(batch.shape[0])
-        built_tensor = MessageHandler.build_tensor(
-            batch.numpy(), "c", "float32", list(batch.shape))
-        self.measure_time("build_tensor")
-        built_model = None
-        if isinstance(model, str):
-            model_arg = MessageHandler.build_model_key(model)
-        else:
-            model_arg = MessageHandler.build_model(model, "resnet-50", "1.0")
-        request = MessageHandler.build_request(
-            reply_channel=self._from_worker_ch_serialized,
-            model= model_arg,
-            inputs=[built_tensor],
-            outputs=[],
-            output_descriptors=[],
-            custom_attributes=None,
-        )
-        self.measure_time("build_request")
-        request_bytes = MessageHandler.serialize_request(request)
-        self.measure_time("serialize_request")
-        with self._to_worker_fli.sendh(timeout=None, stream_channel=self._to_worker_ch) as to_sendh:
-            to_sendh.send_bytes(request_bytes)
-            logger.info(f"Message size: {len(request_bytes)} bytes")
-
-        self.measure_time("send")
-        with self._from_worker_ch.recvh(timeout=None) as from_recvh:
-            resp = from_recvh.recv_bytes(timeout=None)
-            self.measure_time("receive")
-            response = MessageHandler.deserialize_response(resp)
-            self.measure_time("deserialize_response")
-            result = torch.from_numpy(
-                numpy.frombuffer(
-                    response.result.data[0].blob,
-                    dtype=str(response.result.data[0].tensorDescriptor.dataType),
-                )
-            )
-            self.measure_time("deserialize_tensor")
-
-        self.end_timings()
-        return result
-
-    def set_model(self, key: str, model: bytes):
-        self._ddict[key] = model
-
-
-class ResNetWrapper():
-    def __init__(self, name: str, model: str):
-        self._model = torch.jit.load(model)
-        self._name = name
-        buffer = io.BytesIO()
-        scripted = torch.jit.trace(self._model, self.get_batch())
-        torch.jit.save(scripted, buffer)
-        self._serialized_model = buffer.getvalue()
-
-    def get_batch(self, batch_size: int=32):
-        return torch.randn((batch_size, 3, 224, 224), dtype=torch.float32)
-
-    @property
-    def model(self):
-        return self._serialized_model
-
-    @property
-    def name(self):
-        return self._name
-
-if __name__ == "__main__":
-
-    parser = argparse.ArgumentParser("Mock application")
-    parser.add_argument("--device", default="cpu")
-    args = parser.parse_args()
-
-    resnet = ResNetWrapper("resnet50", f"resnet50.{args.device.upper()}.pt")
-
-    client = ProtoClient(timing_on=True)
-    client.set_model(resnet.name, resnet.model)
-
-    total_iterations = 100
-
-    for batch_size in [1, 2, 4, 8, 16, 32, 64, 128]:
-        logger.info(f"Batch size: {batch_size}")
-        for iteration_number in range(total_iterations + int(batch_size==1)):
-            logger.info(f"Iteration: {iteration_number}")
-            client.run_model(resnet.name, resnet.get_batch(batch_size))
-
-    client.print_timings(to_file=True)
\ No newline at end of file
diff --git a/ex/high_throughput_inference/mock_app_redis.py b/ex/high_throughput_inference/mock_app_redis.py
deleted file mode 100644
index c56b4fb8b4..0000000000
--- a/ex/high_throughput_inference/mock_app_redis.py
+++ /dev/null
@@ -1,88 +0,0 @@
-# BSD 2-Clause License
-#
-# Copyright (c) 2021-2024, Hewlett Packard Enterprise
-# All rights reserved.
-#
-# Redistribution and use in source and binary forms, with or without
-# modification, are permitted provided that the following conditions are met:
-#
-# 1. Redistributions of source code must retain the above copyright notice, this
-#    list of conditions and the following disclaimer.
-#
-# 2. Redistributions in binary form must reproduce the above copyright notice,
-#    this list of conditions and the following disclaimer in the documentation
-#    and/or other materials provided with the distribution.
-#
-# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
-# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
-# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
-# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
-# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
-# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
-# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
-# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
-# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
-# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-
-import argparse
-import io
-import numpy
-import time
-import torch
-from smartsim.log import get_logger
-from smartredis import Client
-
-logger = get_logger("App")
-
-class ResNetWrapper():
-    def __init__(self, name: str, model: str):
-        self._model = torch.jit.load(model)
-        self._name = name
-        buffer = io.BytesIO()
-        scripted = torch.jit.trace(self._model, self.get_batch())
-        torch.jit.save(scripted, buffer)
-        self._serialized_model = buffer.getvalue()
-
-    def get_batch(self, batch_size: int=32):
-        return torch.randn((batch_size, 3, 224, 224), dtype=torch.float32)
-
-    @property
-    def model(self):
-        return self._serialized_model
-
-    @property
-    def name(self):
-        return self._name
-
-if __name__ == "__main__":
-
-    parser = argparse.ArgumentParser("Mock application")
-    parser.add_argument("--device", default="cpu")
-    args = parser.parse_args()
-
-    resnet = ResNetWrapper("resnet50", f"resnet50.{args.device.upper()}.pt")
-
-    client = Client(cluster=False, address=None)
-    client.set_model(resnet.name, resnet.model, backend='TORCH', device=args.device.upper())
-
-    total_iterations = 100
-    timings=[]
-    for batch_size in [1, 2, 4, 8, 16, 32, 64, 128]:
-        logger.info(f"Batch size: {batch_size}")
-        for iteration_number in range(total_iterations + int(batch_size==1)):
-            timing = [batch_size]
-            logger.info(f"Iteration: {iteration_number}")
-            start = time.perf_counter()
-            client.put_tensor(name="batch", data=resnet.get_batch(batch_size).numpy())
-            client.run_model(name=resnet.name, inputs=["batch"], outputs=["result"])
-            result = client.get_tensor(name="result")
-            end = time.perf_counter()
-            timing.append(end-start)
-            timings.append(timing)
-
-
-
-    timings_np = numpy.asarray(timings)
-    numpy.save("timings.npy", timings_np)
-    for timing in timings:
-        print(" ".join(str(t) for t in timing))
diff --git a/ex/high_throughput_inference/redis_driver.py b/ex/high_throughput_inference/redis_driver.py
deleted file mode 100644
index ceddba4ef7..0000000000
--- a/ex/high_throughput_inference/redis_driver.py
+++ /dev/null
@@ -1,65 +0,0 @@
-# BSD 2-Clause License
-#
-# Copyright (c) 2021-2024, Hewlett Packard Enterprise
-# All rights reserved.
-#
-# Redistribution and use in source and binary forms, with or without
-# modification, are permitted provided that the following conditions are met:
-#
-# 1. Redistributions of source code must retain the above copyright notice, this
-#    list of conditions and the following disclaimer.
-#
-# 2. Redistributions in binary form must reproduce the above copyright notice,
-#    this list of conditions and the following disclaimer in the documentation
-#    and/or other materials provided with the distribution.
-#
-# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
-# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
-# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
-# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
-# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
-# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
-# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
-# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
-# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
-# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-
-import os
-import sys
-from smartsim import Experiment
-from smartsim.status import TERMINAL_STATUSES
-import time
-import typing as t
-
-device = "gpu"
-filedir = os.path.dirname(__file__)
-app_script_name = os.path.join(filedir, "mock_app_redis.py")
-model_name = os.path.join(filedir, f"resnet50.{device.upper()}.pt")
-
-
-exp_path = os.path.join(filedir, "redis_ai")
-os.makedirs(exp_path, exist_ok=True)
-exp = Experiment("redis_ai", launcher="slurm", exp_path=exp_path)
-
-db = exp.create_database(interface="hsn0")
-
-app_rs = exp.create_run_settings(sys.executable, exe_args = [app_script_name, "--device", device])
-app_rs.set_nodes(1)
-app_rs.set_tasks(1)
-app = exp.create_model("app", run_settings=app_rs)
-app.attach_generator_files(to_copy=[app_script_name], to_symlink=[model_name])
-
-exp.generate(db, app, overwrite=True)
-
-exp.start(db, app, block=False)
-
-while True:
-    if exp.get_status(app)[0] in TERMINAL_STATUSES:
-        exp.stop(db)
-        break
-    if exp.get_status(db)[0] in TERMINAL_STATUSES:
-        exp.stop(app)
-        break
-    time.sleep(5)
-
-print("Exiting.")
\ No newline at end of file
diff --git a/ex/high_throughput_inference/standalone_workermanager.py b/ex/high_throughput_inference/standalone_workermanager.py
deleted file mode 100644
index c56e11a7c3..0000000000
--- a/ex/high_throughput_inference/standalone_workermanager.py
+++ /dev/null
@@ -1,96 +0,0 @@
-# BSD 2-Clause License
-#
-# Copyright (c) 2021-2024, Hewlett Packard Enterprise
-# All rights reserved.
-#
-# Redistribution and use in source and binary forms, with or without
-# modification, are permitted provided that the following conditions are met:
-#
-# 1. Redistributions of source code must retain the above copyright notice, this
-#    list of conditions and the following disclaimer.
-#
-# 2. Redistributions in binary form must reproduce the above copyright notice,
-#    this list of conditions and the following disclaimer in the documentation
-#    and/or other materials provided with the distribution.
-#
-# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
-# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
-# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
-# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
-# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
-# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
-# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
-# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
-# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
-# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
- -# isort: off -import dragon -from dragon import fli -from dragon.channels import Channel -from dragon.data.ddict.ddict import DDict -from dragon.utils import b64decode, b64encode -from dragon.globalservices.api_setup import connect_to_infrastructure -# isort: on -import argparse -import base64 -import cloudpickle -import pickle -import os - -from smartsim._core.mli.comm.channel.dragonchannel import DragonCommChannel -from smartsim._core.mli.infrastructure.storage.dragonfeaturestore import DragonFeatureStore -from smartsim._core.mli.comm.channel.dragonfli import DragonFLIChannel -from smartsim._core.mli.infrastructure.worker.torch_worker import TorchWorker -from smartsim._core.mli.infrastructure.control.workermanager import WorkerManager -from smartsim._core.mli.infrastructure.environmentloader import EnvironmentConfigLoader - - -if __name__ == "__main__": - parser = argparse.ArgumentParser("Worker Manager") - parser.add_argument( - "--device", - type=str, - default="gpu", - choices="gpu cpu".split(), - help="Device on which the inference takes place", - ) - parser.add_argument( - "--worker_class", - type=str, - required=True, - help="Serialized class of worker to run", - ) - parser.add_argument( - "--num_workers", type=int, default=1, help="Number of workers to run" - ) - - args = parser.parse_args() - connect_to_infrastructure() - ddict_str = os.environ["SS_DRG_DDICT"] - ddict = DDict.attach(ddict_str) - - to_worker_channel = Channel.make_process_local() - to_worker_fli = fli.FLInterface(main_ch=to_worker_channel, manager_ch=None) - to_worker_fli_serialized = to_worker_fli.serialize() - ddict["to_worker_fli"] = to_worker_fli_serialized - - torch_worker = cloudpickle.loads(base64.b64decode(args.worker_class.encode('ascii')))() - - dfs = DragonFeatureStore(ddict) - comm_channel = DragonFLIChannel(to_worker_fli_serialized) - - os.environ["SSFeatureStore"] = base64.b64encode(pickle.dumps(dfs)).decode("utf-8") - os.environ["SSQueue"] = base64.b64encode(to_worker_fli_serialized).decode("utf-8") - - config_loader = EnvironmentConfigLoader() - - worker_manager = WorkerManager( - config_loader=config_loader, - worker=torch_worker, - as_service=True, - cooldown=10, - comm_channel_type=DragonCommChannel, - device = args.device, - ) - worker_manager.execute() diff --git a/smartsim/_core/entrypoints/service.py b/smartsim/_core/entrypoints/service.py index df9c2bbef6..e03df6bea1 100644 --- a/smartsim/_core/entrypoints/service.py +++ b/smartsim/_core/entrypoints/service.py @@ -46,8 +46,7 @@ def __init__( :param as_service: Determines if the host will run until shutdown criteria are met or as a run-once instance :param cooldown: Period of time to allow service to run before automatic - shutdown, in seconds. A non-zero, positive integer. - :param loop_delay: delay between iterations of the event loop""" + shutdown, in seconds. 
A non-zero, positive integer.""" self._as_service = as_service """If the service should run until shutdown function returns True""" self._cooldown = abs(cooldown) @@ -103,23 +102,6 @@ def execute(self) -> None: running = True cooldown_start: t.Optional[datetime.datetime] = None - headers = [ - "batch_size", - "w_deserialize", - "w_fetch_model", - "w_load_model", - "w_fetch_input", - "w_transform_input", - "w_execute", - "w_transform_output", - "w_assign_output", - "w_build_reply", - "w_serialize_resp", - "w_send", - ] - - print(",".join(headers)) - while running: self._on_iteration() diff --git a/smartsim/_core/launcher/dragon/dragonBackend.py b/smartsim/_core/launcher/dragon/dragonBackend.py index dcc5c8392b..2456606623 100644 --- a/smartsim/_core/launcher/dragon/dragonBackend.py +++ b/smartsim/_core/launcher/dragon/dragonBackend.py @@ -36,10 +36,8 @@ # pylint: disable=import-error # isort: off -import dragon.data.ddict.ddict as dragon_ddict import dragon.infrastructure.connection as dragon_connection import dragon.infrastructure.policy as dragon_policy -import dragon.infrastructure.process_desc as dragon_process_desc import dragon.native.group_state as dragon_group_state import dragon.native.process as dragon_process import dragon.native.process_group as dragon_process_group @@ -189,7 +187,6 @@ def __init__(self, pid: int) -> None: self._view = DragonBackendView(self) logger.debug(self._view.host_desc) - self._infra_ddict: t.Optional[dragon_ddict.DDict] = None @property def hosts(self) -> list[str]: @@ -394,22 +391,6 @@ def _stop_steps(self) -> None: self._group_infos[step_id].status = SmartSimStatus.STATUS_CANCELLED self._group_infos[step_id].return_codes = [-9] - @property - def infra_ddict(self) -> str: - """Create a Dragon distributed dictionary and return its - serialized descriptor - """ - if self._infra_ddict is None: - logger.info("Creating DDict") - self._infra_ddict = dragon_ddict.DDict( - n_nodes=len(self._hosts), total_mem=len(self._hosts) * 1024**3 - ) # todo: parametrize - logger.info("Created DDict") - self._infra_ddict["creation"] = str(time.time()) - logger.info(self._infra_ddict["creation"]) - - return str(self._infra_ddict.serialize()) - def _start_steps(self) -> None: self._heartbeat() with self._queue_lock: @@ -425,7 +406,6 @@ def _start_steps(self) -> None: placement=dragon_policy.Policy.Placement.HOST_NAME, host_name=hosts[0], ) - options = dragon_process_desc.ProcessOptions(make_inf_channels=True) grp = dragon_process_group.ProcessGroup( restart=False, pmi_enabled=request.pmi_enabled, policy=global_policy ) @@ -441,15 +421,10 @@ def _start_steps(self) -> None: target=request.exe, args=request.exe_args, cwd=request.path, - env={ - **request.current_env, - **request.env, - "SS_DRG_DDICT": self.infra_ddict, - }, + env={**request.current_env, **request.env}, stdout=dragon_process.Popen.PIPE, stderr=dragon_process.Popen.PIPE, policy=local_policy, - options=options, ) grp.add_process(nproc=request.tasks_per_node, template=tmp_proc) diff --git a/smartsim/_core/mli/comm/channel/channel.py b/smartsim/_core/mli/comm/channel/channel.py index 2318896a9b..201ab9deab 100644 --- a/smartsim/_core/mli/comm/channel/channel.py +++ b/smartsim/_core/mli/comm/channel/channel.py @@ -41,14 +41,9 @@ def __init__(self, descriptor: t.Union[str, bytes]) -> None: @abstractmethod def send(self, value: bytes) -> None: - """Send a message through the underlying communication channel + """Send a message throuh the underlying communication channel :param value: The value to send""" - 
@abstractmethod - def recv(self) -> bytes: - """Receieve a message through the underlying communication channel - :returns: the received message""" - @property def descriptor(self) -> bytes: """Return the channel descriptor for the underlying dragon channel""" diff --git a/smartsim/_core/mli/comm/channel/dragonchannel.py b/smartsim/_core/mli/comm/channel/dragonchannel.py index 1409747a91..4fd26861ca 100644 --- a/smartsim/_core/mli/comm/channel/dragonchannel.py +++ b/smartsim/_core/mli/comm/channel/dragonchannel.py @@ -24,18 +24,16 @@ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -import sys +import typing as t import smartsim._core.mli.comm.channel.channel as cch from smartsim.log import get_logger logger = get_logger(__name__) -try: +if t.TYPE_CHECKING: import dragon.channels as dch -except ImportError as exc: - if not "pytest" in sys.modules: - raise exc from None + import dragon.utils as du class DragonCommChannel(cch.CommChannelBase): @@ -44,17 +42,11 @@ class DragonCommChannel(cch.CommChannelBase): def __init__(self, key: bytes) -> None: """Initialize the DragonCommChannel instance""" super().__init__(key) - self._channel: dch.Channel = dch.Channel.attach(key) + # todo: do we need memory pool information to construct the channel correctly? + self._channel: "dch.Channel" = du.get_channel(key) def send(self, value: bytes) -> None: """Send a message throuh the underlying communication channel :param value: The value to send""" - with self._channel.sendh(timeout=None) as sendh: - sendh.send_bytes(value) - - def recv(self) -> bytes: - """Receieve a message through the underlying communication channel - :returns: the received message""" - with self._channel.recvh(timeout=None) as recvh: - message_bytes: bytes = recvh.recv_bytes(timeout=None) - return message_bytes + logger.debug(f"Channel {self.descriptor.decode('utf-8')} sending message") + self._channel.send_bytes(value) diff --git a/smartsim/_core/mli/comm/channel/dragonfli.py b/smartsim/_core/mli/comm/channel/dragonfli.py deleted file mode 100644 index 75f8fb4bfc..0000000000 --- a/smartsim/_core/mli/comm/channel/dragonfli.py +++ /dev/null @@ -1,69 +0,0 @@ -# BSD 2-Clause License -# -# Copyright (c) 2021-2024, Hewlett Packard Enterprise -# All rights reserved. -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are met: -# -# 1. Redistributions of source code must retain the above copyright notice, this -# list of conditions and the following disclaimer. -# -# 2. Redistributions in binary form must reproduce the above copyright notice, -# this list of conditions and the following disclaimer in the documentation -# and/or other materials provided with the distribution. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -# DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -# isort: off -from dragon import fli -import dragon.channels as dch - -# isort: on - -import sys -import typing as t - -import smartsim._core.mli.comm.channel.channel as cch -from smartsim.log import get_logger - -logger = get_logger(__name__) - - -class DragonFLIChannel(cch.CommChannelBase): - """Passes messages by writing to a Dragon FLI Channel""" - - def __init__(self, fli_desc: bytes, sender_supplied: bool = True) -> None: - """Initialize the DragonFLIChannel instance""" - super().__init__(fli_desc) - # todo: do we need memory pool information to construct the channel correctly? - self._fli: "fli" = fli.FLInterface.attach(fli_desc) - self._channel: t.Optional["dch"] = ( - dch.Channel.make_process_local() if sender_supplied else None - ) - - def send(self, value: bytes) -> None: - """Send a message through the underlying communication channel - :param value: The value to send""" - with self._fli.sendh(timeout=None, stream_channel=self._channel) as sendh: - sendh.send_bytes(value) - - def recv(self) -> bytes: - """Receieve a message through the underlying communication channel - :returns: the received message""" - with self._fli.recvh(timeout=None) as recvh: - try: - request_bytes: bytes - request_bytes, _ = recvh.recv_bytes(timeout=None) - return request_bytes - except fli.FLIEOT as exc: - return b"" diff --git a/smartsim/_core/mli/infrastructure/control/workermanager.py b/smartsim/_core/mli/infrastructure/control/workermanager.py index c3950fa987..fc87fb1431 100644 --- a/smartsim/_core/mli/infrastructure/control/workermanager.py +++ b/smartsim/_core/mli/infrastructure/control/workermanager.py @@ -24,34 +24,24 @@ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
-import sys - -# isort: off -import dragon -from dragon import fli - -# isort: on - -import time +import multiprocessing as mp import typing as t import numpy as np -from .....error import SmartSimError -from .....log import get_logger -from ....entrypoints.service import Service -from ...comm.channel.channel import CommChannelBase -from ...comm.channel.dragonchannel import DragonCommChannel -from ...infrastructure.environmentloader import EnvironmentConfigLoader -from ...infrastructure.storage.featurestore import FeatureStore -from ...infrastructure.worker.worker import ( +from smartsim._core.entrypoints.service import Service +from smartsim._core.mli.comm.channel.channel import CommChannelBase +from smartsim._core.mli.comm.channel.dragonchannel import DragonCommChannel +from smartsim._core.mli.infrastructure.environmentloader import EnvironmentConfigLoader +from smartsim._core.mli.infrastructure.storage.featurestore import FeatureStore +from smartsim._core.mli.infrastructure.worker.worker import ( InferenceReply, InferenceRequest, - LoadModelResult, MachineLearningWorkerBase, ) -from ...message_handler import MessageHandler -from ...mli_schemas.response.response_capnp import Response +from smartsim._core.mli.message_handler import MessageHandler +from smartsim._core.mli.mli_schemas.response.response_capnp import Response +from smartsim.log import get_logger if t.TYPE_CHECKING: from dragon.fli import FLInterface @@ -63,9 +53,7 @@ def deserialize_message( - data_blob: bytes, - channel_type: t.Type[CommChannelBase], - device: t.Literal["cpu", "gpu"], + data_blob: bytes, channel_type: t.Type[CommChannelBase] ) -> InferenceRequest: """Deserialize a message from a byte stream into an InferenceRequest :param data_blob: The byte stream to deserialize""" @@ -100,6 +88,12 @@ def deserialize_message( ) output_keys: t.Optional[t.List[str]] = None + # # client example + # msg = Message() + # t = torch.Tensor() + # msg.inputs = [custom_byte_converter(t)] + # mli_client.request_inference(msg) + # # end client input_meta: t.List[t.Any] = [] if request.input.which() == "keys": @@ -209,7 +203,6 @@ def __init__( as_service: bool = False, cooldown: int = 0, comm_channel_type: t.Type[CommChannelBase] = DragonCommChannel, - device: t.Literal["cpu", "gpu"] = "cpu", ) -> None: """Initialize the WorkerManager :param config_loader: Environment config loader that loads the task queue and @@ -222,7 +215,8 @@ def __init__( """ super().__init__(as_service, cooldown) - self._task_queue: t.Optional[CommChannelBase] = config_loader.get_queue() + """a collection of workers the manager is controlling""" + self._task_queue: t.Optional["FLInterface"] = config_loader.get_queue() """the queue the manager monitors for new tasks""" self._feature_store: t.Optional[FeatureStore] = ( config_loader.get_feature_store() @@ -232,10 +226,6 @@ def __init__( """The ML Worker implementation""" self._comm_channel_type = comm_channel_type """The type of communication channel to construct for callbacks""" - self._device = device - """Device on which workers need to run""" - self._cached_models: dict[str, t.Any] = {} - """Dictionary of previously loaded models""" def _validate_request(self, request: InferenceRequest) -> bool: """Ensure the request can be processed. 
@@ -277,112 +267,42 @@ def _on_iteration(self) -> None: logger.warning("No queue to check for tasks") return - timings = [] # timing # perform default deserialization of the message envelope - request_bytes: bytes = self._task_queue.recv() + receiver = self._task_queue.recvh(use_main_as_stream_channel=True) + request_bytes, _ = receiver.recv_bytes() - interm = time.perf_counter() # timing - request = deserialize_message( - request_bytes, self._comm_channel_type, self._device - ) + request = deserialize_message(request_bytes, self._comm_channel_type) if not self._validate_request(request): return - timings.append(time.perf_counter() - interm) # timing - interm = time.perf_counter() # timing - reply = InferenceReply() - if not request.raw_model: - if request.model_key is None: - response = build_failure_reply("fail", "Could not read model key.") - serialized_resp = MessageHandler.serialize_response(response) # type: ignore - if request.callback: - request.callback.send(serialized_resp) - return - if request.model_key in self._cached_models: - timings.append(time.perf_counter() - interm) # timing - interm = time.perf_counter() # timing - model_result = LoadModelResult(self._cached_models[request.model_key]) - - else: - fetch_model_result = None - while True: - try: - interm = time.perf_counter() # timing - fetch_model_result = self._worker.fetch_model( - request, self._feature_store - ) - except KeyError: - time.sleep(0.1) - except Exception as e: - exception_handler( - e, request.callback, "fetching the model", reply - ) - break - return - - if fetch_model_result is None: - response = build_failure_reply( - "fail", "Could not retrieve model from feature store." - ) - serialized_resp = MessageHandler.serialize_response(response) # type: ignore - if request.callback: - request.callback.send(serialized_resp) - return - else: - timings.append(time.perf_counter() - interm) # timing - interm = time.perf_counter() # timing - try: - model_result = self._worker.load_model( - request, - fetch_result=fetch_model_result, - device=self._device, - ) - self._cached_models[request.model_key] = model_result.model - except Exception as e: - exception_handler( - e, request.callback, "loading the model", reply - ) - return + try: + fetch_model_result = self._worker.fetch_model(request, self._feature_store) + except Exception as e: + exception_handler(e, request.callback, "fetching the model", reply) + return - else: - try: - fetch_model_result = self._worker.fetch_model( - request, self._feature_store - ) - except Exception as e: - exception_handler(e, request.callback, "fetching the model", reply) - return - try: - model_result = self._worker.load_model( - request, fetch_result=fetch_model_result, device=self._device - ) - except Exception as e: - exception_handler(e, request.callback, "loading the model", reply) - return + try: + model_result = self._worker.load_model(request, fetch_model_result) + except Exception as e: + exception_handler(e, request.callback, "loading the model", reply) + return - timings.append(time.perf_counter() - interm) # timing - interm = time.perf_counter() # timing try: fetch_input_result = self._worker.fetch_inputs(request, self._feature_store) except Exception as e: exception_handler(e, request.callback, "fetching the inputs", reply) return - timings.append(time.perf_counter() - interm) # timing - interm = time.perf_counter() # timing try: transformed_input = self._worker.transform_input( - request, fetch_input_result, self._device + request, fetch_input_result ) except Exception 
as e: exception_handler(e, request.callback, "transforming the input", reply) return - timings.append(time.perf_counter() - interm) # timing - interm = time.perf_counter() # timing - try: execute_result = self._worker.execute( request, model_result, transformed_input @@ -391,18 +311,12 @@ def _on_iteration(self) -> None: exception_handler(e, request.callback, "executing", reply) return - timings.append(time.perf_counter() - interm) # timing - interm = time.perf_counter() # timing try: - transformed_output = self._worker.transform_output( - request, execute_result, self._device - ) + transformed_output = self._worker.transform_output(request, execute_result) except Exception as e: exception_handler(e, request.callback, "transforming the output", reply) return - timings.append(time.perf_counter() - interm) # timing - interm = time.perf_counter() # timing if request.output_keys: try: reply.output_keys = self._worker.place_output( @@ -416,31 +330,18 @@ def _on_iteration(self) -> None: else: reply.outputs = transformed_output.outputs - timings.append(time.perf_counter() - interm) # timing - interm = time.perf_counter() # timing - if reply.outputs is None or not reply.outputs: - response = build_failure_reply("fail", "no-results") + response = build_failure_reply("fail", "Outputs not found.") + else: reply.status_enum = "complete" reply.message = "Success" response = build_reply(reply) - timings.append(time.perf_counter() - interm) # timing - interm = time.perf_counter() # timing - serialized_resp = MessageHandler.serialize_response(response) # type: ignore - - timings.append(time.perf_counter() - interm) # timing - interm = time.perf_counter() # timing if request.callback: request.callback.send(serialized_resp) - timings.append(time.perf_counter() - interm) # timing - interm = time.perf_counter() # timing - - print(" ".join(str(time) for time in timings)) # timing - def _can_shutdown(self) -> bool: """Return true when the criteria to shut down the service are met.""" # todo: determine shutdown criteria diff --git a/smartsim/_core/mli/infrastructure/environmentloader.py b/smartsim/_core/mli/infrastructure/environmentloader.py index 9f6770623d..267b668f63 100644 --- a/smartsim/_core/mli/infrastructure/environmentloader.py +++ b/smartsim/_core/mli/infrastructure/environmentloader.py @@ -31,7 +31,6 @@ from dragon.fli import FLInterface # pylint: disable=all -from smartsim._core.mli.comm.channel.dragonfli import DragonFLIChannel from smartsim._core.mli.infrastructure.storage.featurestore import FeatureStore @@ -42,12 +41,10 @@ class EnvironmentConfigLoader: """ def __init__(self) -> None: - self._feature_store_descriptor: t.Optional[str] = os.getenv( - "SSFeatureStore", None - ) - self._queue_descriptor: t.Optional[str] = os.getenv("SSQueue", None) + self._feature_store_descriptor = os.getenv("SSFeatureStore", None) + self._queue_descriptor = os.getenv("SSQueue", None) self.feature_store: t.Optional[FeatureStore] = None - self.queue: t.Optional[DragonFLIChannel] = None + self.queue: t.Optional["FLInterface"] = None def get_feature_store(self) -> t.Optional[FeatureStore]: """Loads the Feature Store previously set in SSFeatureStore""" @@ -57,11 +54,8 @@ def get_feature_store(self) -> t.Optional[FeatureStore]: ) return self.feature_store - def get_queue(self, sender_supplied: bool = True) -> t.Optional[DragonFLIChannel]: + def get_queue(self) -> t.Optional["FLInterface"]: """Returns the Queue previously set in SSQueue""" if self._queue_descriptor is not None: - self.queue = DragonFLIChannel( - 
fli_desc=base64.b64decode(self._queue_descriptor), - sender_supplied=sender_supplied, - ) + self.queue = FLInterface.attach(base64.b64decode(self._queue_descriptor)) return self.queue diff --git a/smartsim/_core/mli/infrastructure/storage/dragonfeaturestore.py b/smartsim/_core/mli/infrastructure/storage/dragonfeaturestore.py index af592ed0ab..8153255d0a 100644 --- a/smartsim/_core/mli/infrastructure/storage/dragonfeaturestore.py +++ b/smartsim/_core/mli/infrastructure/storage/dragonfeaturestore.py @@ -44,28 +44,27 @@ def __init__(self, storage: "DDict") -> None: """Initialize the DragonFeatureStore instance""" self._storage = storage - def __getitem__(self, key: str) -> t.Union[str, bytes]: + def __getitem__(self, key: str) -> t.Any: """Retrieve an item using key :param key: Unique key of an item to retrieve from the feature store""" + key_ = key.encode("utf-8") try: - value: t.Union[str, bytes] = self._storage[key] - return value - except KeyError as ex: - raise ex + return self._storage[key_] except Exception as ex: # note: explicitly avoid round-trip to check for key existence - raise sse.SmartSimError( - f"Could not get value for existing key {key}, error:\n{ex}" - ) from ex + raise sse.SmartSimError(f"{key} not found in feature store") from ex - def __setitem__(self, key: str, value: t.Union[str, bytes]) -> None: + def __setitem__(self, key: str, value: bytes) -> None: """Assign a value using key :param key: Unique key of an item to set in the feature store :param value: Value to persist in the feature store""" - self._storage[key] = value + key_ = key.encode("utf-8") + self._storage[key_] = value - def __contains__(self, key: str) -> bool: + def __contains__(self, key: t.Union[str, bytes]) -> bool: """Membership operator to test for a key existing within the feature store. Return `True` if the key is found, `False` otherwise :param key: Unique key of an item to retrieve from the feature store""" + if isinstance(key, str): + key = key.encode("utf-8") return key in self._storage diff --git a/smartsim/_core/mli/infrastructure/storage/featurestore.py b/smartsim/_core/mli/infrastructure/storage/featurestore.py index 553e13b10f..ec4086b732 100644 --- a/smartsim/_core/mli/infrastructure/storage/featurestore.py +++ b/smartsim/_core/mli/infrastructure/storage/featurestore.py @@ -24,7 +24,6 @@ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -import typing as t from abc import ABC, abstractmethod @@ -33,12 +32,12 @@ class FeatureStore(ABC): values from a feature store implementation""" @abstractmethod - def __getitem__(self, key: str) -> t.Union[str, bytes]: + def __getitem__(self, key: str) -> bytes: """Retrieve an item using key :param key: Unique key of an item to retrieve from the feature store""" @abstractmethod - def __setitem__(self, key: str, value: t.Union[str, bytes]) -> None: + def __setitem__(self, key: str, value: bytes) -> None: """Assign a value using key :param key: Unique key of an item to set in the feature store :param value: Value to persist in the feature store""" diff --git a/smartsim/_core/mli/infrastructure/worker/torch_worker.py b/smartsim/_core/mli/infrastructure/worker/torch_worker.py deleted file mode 100644 index a4e725ab99..0000000000 --- a/smartsim/_core/mli/infrastructure/worker/torch_worker.py +++ /dev/null @@ -1,119 +0,0 @@ -# BSD 2-Clause License -# -# Copyright (c) 2021-2024, Hewlett Packard Enterprise -# All rights reserved. 
-# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are met: -# -# 1. Redistributions of source code must retain the above copyright notice, this -# list of conditions and the following disclaimer. -# -# 2. Redistributions in binary form must reproduce the above copyright notice, -# this list of conditions and the following disclaimer in the documentation -# and/or other materials provided with the distribution. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -import io - -import numpy as np -import torch - -from .....error import SmartSimError -from .....log import get_logger -from ...mli_schemas.tensor import tensor_capnp -from .worker import ( - ExecuteResult, - FetchInputResult, - FetchModelResult, - InferenceRequest, - LoadModelResult, - MachineLearningWorkerBase, - TransformInputResult, - TransformOutputResult, -) - -logger = get_logger(__name__) - - -class TorchWorker(MachineLearningWorkerBase): - """A worker that executes a PyTorch model.""" - - @staticmethod - def load_model( - request: InferenceRequest, fetch_result: FetchModelResult, device: str - ) -> LoadModelResult: - if fetch_result.model_bytes: - model_bytes = fetch_result.model_bytes - elif request.raw_model and request.raw_model.data: - model_bytes = request.raw_model.data - else: - raise ValueError("Unable to load model without reference object") - - device_to_torch = {"cpu": "cpu", "gpu": "cuda"} - device = device_to_torch[device] - buffer = io.BytesIO(initial_bytes=model_bytes) - model = torch.jit.load(buffer, map_location=device) # type: ignore - result = LoadModelResult(model) - return result - - @staticmethod - def transform_input( - request: InferenceRequest, fetch_result: FetchInputResult, device: str - ) -> TransformInputResult: - result = [] - - device_to_torch = {"cpu": "cpu", "gpu": "cuda"} - device = device_to_torch[device] - if fetch_result.meta is None: - raise ValueError("Cannot reconstruct tensor without meta information") - for item, item_meta in zip(fetch_result.inputs, fetch_result.meta): - tensor_desc: tensor_capnp.TensorDescriptor = item_meta - result.append( - torch.tensor(np.frombuffer(item, dtype=str(tensor_desc.dataType))) - .to(device) - .reshape(tuple(dim for dim in tensor_desc.dimensions)) - ) - return TransformInputResult(result) - # return data # note: this fails copy test! 
- - @staticmethod - def execute( - request: InferenceRequest, - load_result: LoadModelResult, - transform_result: TransformInputResult, - ) -> ExecuteResult: - if not load_result.model: - raise SmartSimError("Model must be loaded to execute") - - model: torch.nn.Module = load_result.model - model.eval() - results = [model(tensor).detach() for tensor in transform_result.transformed] - - execute_result = ExecuteResult(results) - return execute_result - - @staticmethod - def transform_output( - request: InferenceRequest, - execute_result: ExecuteResult, - result_device: str, - ) -> TransformOutputResult: - if result_device != "cpu": - transformed = [item.to("cpu") for item in execute_result.predictions] - # todo: need the shape from latest schemas added here. - return TransformOutputResult(transformed, None, "c", "float32") # fixme - - return TransformOutputResult( - execute_result.predictions, None, "c", "float32" - ) # fixme diff --git a/smartsim/_core/mli/infrastructure/worker/worker.py b/smartsim/_core/mli/infrastructure/worker/worker.py index 3d3b36fbdf..77781f2286 100644 --- a/smartsim/_core/mli/infrastructure/worker/worker.py +++ b/smartsim/_core/mli/infrastructure/worker/worker.py @@ -27,11 +27,11 @@ import typing as t from abc import ABC, abstractmethod -from .....error import SmartSimError -from .....log import get_logger -from ...comm.channel.channel import CommChannelBase -from ...infrastructure.storage.featurestore import FeatureStore -from ...mli_schemas.model.model_capnp import Model +import smartsim.error as sse +from smartsim._core.mli.comm.channel.channel import CommChannelBase +from smartsim._core.mli.infrastructure.storage.featurestore import FeatureStore +from smartsim._core.mli.mli_schemas.model.model_capnp import Model +from smartsim.log import get_logger if t.TYPE_CHECKING: from smartsim._core.mli.mli_schemas.response.response_capnp import StatusEnum @@ -110,23 +110,23 @@ def __init__(self, result: t.Any) -> None: class FetchInputResult: """A wrapper around fetched inputs""" - def __init__(self, result: t.List[bytes], meta: t.Optional[t.List[t.Any]]) -> None: + def __init__(self, result: t.List[bytes]) -> None: """Initialize the object""" self.inputs = result - self.meta = meta class TransformOutputResult: """A wrapper around inference results transformed for transmission""" def __init__( - self, result: t.Any, shape: t.Optional[t.List[int]], order: str, dtype: str + self, result: t.Any, shape: t.List[int], order: str, dtype: str ) -> None: """Initialize the OutputTransformResult""" self.outputs = result self.shape = shape self.order = order self.dtype = dtype + # todo: determine if each output must have an individual (shape, order, dtype) class CreateInputBatchResult: @@ -142,7 +142,7 @@ class FetchModelResult: def __init__(self, result: bytes) -> None: """Initialize the object""" - self.model_bytes: bytes = result + self.model_bytes = result class MachineLearningWorkerCore: @@ -156,6 +156,8 @@ def fetch_model( :param request: The request that triggered the pipeline :param feature_store: The feature store used for persistence :return: Raw bytes of the model""" + if not feature_store: + raise ValueError("Feature store is required for model retrieval") if request.raw_model: # Should we cache model in the feature store? 
@@ -164,20 +166,17 @@ def fetch_model( # short-circuit and return the directly supplied model return FetchModelResult(request.raw_model.data) - if not feature_store: - raise ValueError("Feature store is required for model retrieval") - if not request.model_key: - raise SmartSimError( + raise sse.SmartSimError( "Key must be provided to retrieve model from feature store" ) try: - raw_bytes: bytes = t.cast(bytes, feature_store[request.model_key]) + raw_bytes = feature_store[request.model_key] return FetchModelResult(raw_bytes) except FileNotFoundError as ex: logger.exception(ex) - raise SmartSimError( + raise sse.SmartSimError( f"Model could not be retrieved with key {request.model_key}" ) from ex @@ -190,27 +189,24 @@ def fetch_inputs( :param request: The request that triggered the pipeline :param feature_store: The feature store used for persistence :return: the fetched input""" - - if request.raw_inputs: - return FetchInputResult(request.raw_inputs, request.input_meta) - if not feature_store: - raise ValueError("No input and no feature store provided") + raise ValueError("Feature store is required for input retrieval") if request.input_keys: data: t.List[bytes] = [] for input_ in request.input_keys: try: - tensor_bytes = t.cast(bytes, feature_store[input_]) + tensor_bytes = feature_store[input_] data.append(tensor_bytes) except KeyError as ex: logger.exception(ex) - raise SmartSimError( + raise sse.SmartSimError( f"Model could not be retrieved with key {input_}" ) from ex - return FetchInputResult( - data, None - ) # fixme: need to get both tensor and descriptor + return FetchInputResult(data) + + if request.raw_inputs: + return FetchInputResult(request.raw_inputs) raise ValueError("No input source") @@ -258,26 +254,32 @@ class MachineLearningWorkerBase(MachineLearningWorkerCore, ABC): """Abstrct base class providing contract for a machine learning worker implementation.""" + # @staticmethod + # @abstractmethod + # def deserialize(request: InferenceRequest) -> InferenceRequest: + # """Given a collection of data serialized to bytes, convert the bytes + # to a proper representation used by the ML backend + # :param data_blob: inference request as a byte-serialized blob + # :return: InferenceRequest deserialized from the input""" + @staticmethod @abstractmethod def load_model( - request: InferenceRequest, fetch_result: FetchModelResult, device: str + request: InferenceRequest, fetch_result: FetchModelResult ) -> LoadModelResult: """Given a loaded MachineLearningModel, ensure it is loaded into device memory :param request: The request that triggered the pipeline - :param device: The device on which the model must be placed :return: ModelLoadResult wrapping the model loaded for the request""" @staticmethod @abstractmethod def transform_input( - request: InferenceRequest, fetch_result: FetchInputResult, device: str + request: InferenceRequest, fetch_result: FetchInputResult ) -> TransformInputResult: """Given a collection of data, perform a transformation on the data :param request: The request that triggered the pipeline :param fetch_result: Raw output from fetching inputs out of a feature store - :param device: The device on which the transformed input must be placed :return: The transformed inputs wrapped in a InputTransformResult""" @staticmethod @@ -296,11 +298,20 @@ def execute( @staticmethod @abstractmethod def transform_output( - request: InferenceRequest, execute_result: ExecuteResult, result_device: str + request: InferenceRequest, + execute_result: ExecuteResult, ) -> 
TransformOutputResult: """Given inference results, perform transformations required to transmit results to the requestor. :param request: The request that triggered the pipeline :param execute_result: The result of inference wrapped in an ExecuteResult - :param result_device: The device on which the result of inference is placed :return:""" + + # @staticmethod + # @abstractmethod + # def serialize_reply( + # request: InferenceRequest, results: OutputTransformResult + # ) -> bytes: + # """Given an output, serialize to bytes for transport + # :param reply: The result of the inference pipeline + # :return: a byte-serialized version of the reply""" diff --git a/smartsim/_core/mli/message_handler.py b/smartsim/_core/mli/message_handler.py index bcf1cfdf14..16cb242b7c 100644 --- a/smartsim/_core/mli/message_handler.py +++ b/smartsim/_core/mli/message_handler.py @@ -396,9 +396,7 @@ def deserialize_request(request_bytes: bytes) -> request_capnp.Request: :param request_bytes: Bytes to be deserialized into a Request """ - bytes_message = request_capnp.Request.from_bytes( - request_bytes, traversal_limit_in_words=2**63 - ) + bytes_message = request_capnp.Request.from_bytes(request_bytes) with bytes_message as message: return message @@ -491,7 +489,7 @@ def _assign_custom_response_attributes( response.customAttributes.tf = custom_attrs # type: ignore else: raise ValueError("""Invalid custom attribute class name. - Expected 'TensorFlowResponseAttributes' or + Expected 'TensorFlowResponseAttributes' or 'TorchResponseAttributes'.""") except Exception as e: raise ValueError("Error assigning custom attributes to response.") from e @@ -536,9 +534,7 @@ def deserialize_response(response_bytes: bytes) -> response_capnp.Response: """ Deserializes a serialized response message. 
""" - bytes_message = response_capnp.Response.from_bytes( - response_bytes, traversal_limit_in_words=2**63 - ) + bytes_message = response_capnp.Response.from_bytes(response_bytes) with bytes_message as message: return message diff --git a/tests/dragon/test_environment_loader.py b/tests/dragon/test_environment_loader.py index 00db0a9d32..d339fec885 100644 --- a/tests/dragon/test_environment_loader.py +++ b/tests/dragon/test_environment_loader.py @@ -64,9 +64,10 @@ def test_environment_loader_attach_FLI(content, monkeypatch): config = EnvironmentConfigLoader() config_queue = config.get_queue() - new_sender = config_queue.send(content) + new_sender = config_queue.sendh(use_main_as_stream_channel=True) + new_sender.send_bytes(content) - old_recv = queue.recvh() + old_recv = queue.recvh(use_main_as_stream_channel=True) result, _ = old_recv.recv_bytes() assert result == content @@ -80,7 +81,7 @@ def test_environment_loader_serialize_FLI(monkeypatch): config = EnvironmentConfigLoader() config_queue = config.get_queue() - assert config_queue._fli.serialize() == queue.serialize() + assert config_queue.serialize() == queue.serialize() def test_environment_loader_FLI_fails(monkeypatch): diff --git a/tests/dragon/test_error_handling.py b/tests/dragon/test_error_handling.py index 2aa81e666f..4ced3e0ff2 100644 --- a/tests/dragon/test_error_handling.py +++ b/tests/dragon/test_error_handling.py @@ -89,7 +89,8 @@ def setup_worker_manager(test_dir, monkeypatch: pytest.MonkeyPatch): test_dir, model, [tensor_key], [tensor_key], [], None ) ser_request = MessageHandler.serialize_request(request) - new_sender = worker_manager._task_queue.send(ser_request) + new_sender = worker_manager._task_queue.sendh(use_main_as_stream_channel=True) + new_sender.send_bytes(ser_request) return worker_manager, integrated_worker @@ -162,7 +163,7 @@ def test_pipeline_stage_errors_handled( monkeypatch.setattr( integrated_worker, "fetch_inputs", - MagicMock(return_value=FetchInputResult([b"result_bytes"], None)), + MagicMock(return_value=FetchInputResult([b"result_bytes"])), ) if stage not in ["fetch_model", "load_model", "fetch_inputs", "transform_input"]: monkeypatch.setattr( diff --git a/tests/dragon/utils/channel.py b/tests/dragon/utils/channel.py index 2523b1a6bf..4bc2014ea3 100644 --- a/tests/dragon/utils/channel.py +++ b/tests/dragon/utils/channel.py @@ -57,8 +57,3 @@ def send(self, value: bytes) -> None: f"Channel {self.descriptor.decode('utf-8')} sending message to {self._file_path}" ) self._file_path.write_bytes(value) - - def recv(self) -> bytes: - """Receive a message through the underlying communication channel - :returns: the received message""" - ... diff --git a/tests/mli/test_torch_worker.py b/tests/mli/test_torch_worker.py deleted file mode 100644 index 0b1cd4ccf3..0000000000 --- a/tests/mli/test_torch_worker.py +++ /dev/null @@ -1,173 +0,0 @@ -# BSD 2-Clause License -# -# Copyright (c) 2021-2024, Hewlett Packard Enterprise -# All rights reserved. -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are met: -# -# 1. Redistributions of source code must retain the above copyright notice, this -# list of conditions and the following disclaimer. -# -# 2. Redistributions in binary form must reproduce the above copyright notice, -# this list of conditions and the following disclaimer in the documentation -# and/or other materials provided with the distribution. 
-# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -import io - -import numpy as np -import pytest -import torch -from torch import nn -from torch.nn import functional as F - -from smartsim._core.mli.infrastructure.worker.torch_worker import TorchWorker -from smartsim._core.mli.infrastructure.worker.worker import ( - ExecuteResult, - FetchInputResult, - FetchModelResult, - InferenceRequest, - LoadModelResult, - TransformInputResult, -) -from smartsim._core.mli.message_handler import MessageHandler -from smartsim.log import get_logger - -logger = get_logger(__name__) -# The tests in this file belong to the group_a group -pytestmark = pytest.mark.group_a - - -# simple MNIST in PyTorch -class Net(nn.Module): - def __init__(self): - super(Net, self).__init__() - self.conv1 = nn.Conv2d(1, 32, 3, 1) - self.conv2 = nn.Conv2d(32, 64, 3, 1) - self.dropout1 = nn.Dropout(0.25) - self.dropout2 = nn.Dropout(0.5) - self.fc1 = nn.Linear(9216, 128) - self.fc2 = nn.Linear(128, 10) - - def forward(self, x): - x = self.conv1(x) - x = F.relu(x) - x = self.conv2(x) - x = F.relu(x) - x = F.max_pool2d(x, 2) - x = self.dropout1(x) - x = torch.flatten(x, 1) - x = self.fc1(x) - x = F.relu(x) - x = self.dropout2(x) - x = self.fc2(x) - output = F.log_softmax(x, dim=1) - return output - - -torch_device = {"cpu": "cpu", "gpu": "cuda"} - - -def get_batch() -> torch.Tensor: - return torch.rand(20, 1, 28, 28) - - -def create_torch_model(): - n = Net() - example_forward_input = get_batch() - module = torch.jit.trace(n, example_forward_input) - model_buffer = io.BytesIO() - torch.jit.save(module, model_buffer) - return model_buffer.getvalue() - - -def get_request() -> InferenceRequest: - - tensors = [get_batch() for _ in range(2)] - serialized_tensors = [ - MessageHandler.build_tensor(tensor.numpy(), "c", "float32", list(tensor.shape)) - for tensor in tensors - ] - - return InferenceRequest( - model_key="model", - callback=None, - raw_inputs=[s_tensor.blob for s_tensor in serialized_tensors], - input_keys=None, - input_meta=[s_tensor.tensorDescriptor for s_tensor in serialized_tensors], - output_keys=None, - raw_model=create_torch_model(), - batch_size=0, - ) - - -sample_request: InferenceRequest = get_request() -worker = TorchWorker() - - -def test_load_model(mlutils) -> None: - fetch_model_result = FetchModelResult(sample_request.raw_model) - load_model_result = worker.load_model( - sample_request, fetch_model_result, mlutils.get_test_device().lower() - ) - - assert load_model_result.model( - get_batch().to(torch_device[mlutils.get_test_device().lower()]) - ).shape == torch.Size((20, 10)) - - -def test_transform_input(mlutils) -> None: - fetch_input_result = FetchInputResult( - sample_request.raw_inputs, sample_request.input_meta - ) - - transform_input_result = worker.transform_input( - 
sample_request, fetch_input_result, mlutils.get_test_device().lower() - ) - - assert all( - transformed.shape == get_batch().shape - for transformed in transform_input_result.transformed - ) - - -def test_execute(mlutils) -> None: - load_model_result = LoadModelResult( - Net().to(torch_device[mlutils.get_test_device().lower()]) - ) - transform_result = TransformInputResult( - [ - get_batch().to(torch_device[mlutils.get_test_device().lower()]) - for _ in range(2) - ] - ) - - execute_result = worker.execute(sample_request, load_model_result, transform_result) - - assert all( - result.shape == torch.Size((20, 10)) for result in execute_result.predictions - ) - - -def test_transform_output(mlutils): - execute_result = ExecuteResult([torch.rand((20, 10)) for _ in range(2)]) - - transformed_output = worker.transform_output( - sample_request, execute_result, torch_device[mlutils.get_test_device().lower()] - ) - - assert transformed_output.outputs == execute_result.predictions - assert transformed_output.shape == None - assert transformed_output.order == "c" - assert transformed_output.dtype == "float32" diff --git a/tests/mli/test_worker_manager.py b/tests/mli/test_worker_manager.py index df4b0a637f..7fd2c9aad4 100644 --- a/tests/mli/test_worker_manager.py +++ b/tests/mli/test_worker_manager.py @@ -29,10 +29,11 @@ import multiprocessing as mp import pathlib import time +import typing as t import pytest +import torch -torch = pytest.importorskip("torch") dragon = pytest.importorskip("dragon") from smartsim._core.mli.infrastructure.control.workermanager import ( diff --git a/tests/test_dragon_backend.py b/tests/test_dragon_backend.py index f284f38d99..a510f660a5 100644 --- a/tests/test_dragon_backend.py +++ b/tests/test_dragon_backend.py @@ -103,16 +103,6 @@ def get_mock_backend(monkeypatch: pytest.MonkeyPatch) -> "DragonBackend": "dragon.infrastructure.connection", MagicMock(), ) - monkeypatch.setitem( - sys.modules, - "dragon.infrastructure.process_desc", - MagicMock(), - ) - monkeypatch.setitem( - sys.modules, - "dragon.data.ddict.ddict", - MagicMock(), - ) monkeypatch.setitem( sys.modules, "dragon.infrastructure.policy", From ac312c2c73c44312db51c17c6cbd01e26eb1fe53 Mon Sep 17 00:00:00 2001 From: Alyssa Cote Date: Tue, 16 Jul 2024 12:12:26 -0500 Subject: [PATCH 44/57] merge fix --- .../infrastructure/control/workermanager.py | 117 ++++++++++-------- 1 file changed, 63 insertions(+), 54 deletions(-) diff --git a/smartsim/_core/mli/infrastructure/control/workermanager.py b/smartsim/_core/mli/infrastructure/control/workermanager.py index 7d7566e613..2ffe089418 100644 --- a/smartsim/_core/mli/infrastructure/control/workermanager.py +++ b/smartsim/_core/mli/infrastructure/control/workermanager.py @@ -280,8 +280,6 @@ def _on_iteration(self) -> None: timings = [] # timing # perform default deserialization of the message envelope request_bytes: bytes = self._task_queue.recv() - receiver = self._task_queue.recvh(use_main_as_stream_channel=True) - request_bytes, _ = receiver.recv_bytes() interm = time.perf_counter() # timing request = deserialize_message( @@ -293,10 +291,15 @@ def _on_iteration(self) -> None: timings.append(time.perf_counter() - interm) # timing interm = time.perf_counter() # timing + reply = InferenceReply() + if not request.raw_model: if request.model_key is None: - # A valid request should never get here. 
- raise ValueError("Could not read model key") + response = build_failure_reply("fail", "Could not read model key.") + serialized_resp = MessageHandler.serialize_response(response) # type: ignore + if request.callback: + request.callback.send(serialized_resp) + return if request.model_key in self._cached_models: timings.append(time.perf_counter() - interm) # timing interm = time.perf_counter() # timing @@ -312,72 +315,85 @@ def _on_iteration(self) -> None: ) except KeyError: time.sleep(0.1) - else: + except Exception as e: + exception_handler( + e, request.callback, "fetching the model", reply + ) break + return if fetch_model_result is None: - raise SmartSimError("Could not retrieve model from feature store") - timings.append(time.perf_counter() - interm) # timing - interm = time.perf_counter() # timing + response = build_failure_reply( + "fail", "Could not retrieve model from feature store." + ) + serialized_resp = MessageHandler.serialize_response(response) # type: ignore + if request.callback: + request.callback.send(serialized_resp) + return + else: + timings.append(time.perf_counter() - interm) # timing + interm = time.perf_counter() # timing + try: + model_result = self._worker.load_model( + request, + fetch_result=fetch_model_result, + device=self._device, + ) + self._cached_models[request.model_key] = model_result.model + except Exception as e: + exception_handler( + e, request.callback, "loading the model", reply + ) + return + + else: + try: + fetch_model_result = self._worker.fetch_model( + request, self._feature_store + ) + except Exception as e: + exception_handler(e, request.callback, "fetching the model", reply) + return + try: model_result = self._worker.load_model( - request, fetch_model_result, self._device + request, fetch_result=fetch_model_result, device=self._device ) - self._cached_models[request.model_key] = model_result.model - else: - fetch_model_result = self._worker.fetch_model(request, None) - model_result = self._worker.load_model( - request, fetch_result=fetch_model_result, device=self._device - ) - - timings.append(time.perf_counter() - interm) # timing - interm = time.perf_counter() # timing - fetch_input_result = self._worker.fetch_inputs(request, self._feature_store) - - timings.append(time.perf_counter() - interm) # timing - interm = time.perf_counter() # timing - transformed_input = self._worker.transform_input( - request, fetch_input_result, self._device - ) + except Exception as e: + exception_handler(e, request.callback, "loading the model", reply) + return timings.append(time.perf_counter() - interm) # timing interm = time.perf_counter() # timing - - reply = InferenceReply() - - try: - fetch_model_result = self._worker.fetch_model(request, self._feature_store) - except Exception as e: - exception_handler(e, request.callback, "fetching the model", reply) - return - - try: - model_result = self._worker.load_model(request, fetch_model_result) - except Exception as e: - exception_handler(e, request.callback, "loading the model", reply) - return - try: fetch_input_result = self._worker.fetch_inputs(request, self._feature_store) except Exception as e: exception_handler(e, request.callback, "fetching the inputs", reply) return + timings.append(time.perf_counter() - interm) # timing + interm = time.perf_counter() # timing try: transformed_input = self._worker.transform_input( - request, fetch_input_result + request, fetch_input_result, self._device ) except Exception as e: exception_handler(e, request.callback, "transforming the input", reply) return + 
timings.append(time.perf_counter() - interm) # timing + interm = time.perf_counter() # timing + try: execute_result = self._worker.execute( request, model_result, transformed_input ) + except Exception as e: + exception_handler(e, request.callback, "executing", reply) + return + timings.append(time.perf_counter() - interm) # timing + interm = time.perf_counter() # timing try: - timings.append(time.perf_counter() - interm) # timing - interm = time.perf_counter() # timing transformed_output = self._worker.transform_output( request, execute_result, self._device ) @@ -385,8 +401,8 @@ def _on_iteration(self) -> None: exception_handler(e, request.callback, "transforming the output", reply) return - timings.append(time.perf_counter() - interm) # timing - interm = time.perf_counter() # timing + timings.append(time.perf_counter() - interm) # timing + interm = time.perf_counter() # timing if request.output_keys: try: reply.output_keys = self._worker.place_output( @@ -404,14 +420,7 @@ def _on_iteration(self) -> None: interm = time.perf_counter() # timing if reply.outputs is None or not reply.outputs: - response = build_failure_reply("fail", "Outputs not found.") - - else: - reply.outputs = transformed_output.outputs - - if reply.outputs is None or not reply.outputs: - response = build_failure_reply("fail", "Outputs not found.") - + response = build_failure_reply("fail", "no-results") else: reply.status_enum = "complete" reply.message = "Success" @@ -441,4 +450,4 @@ def _can_shutdown(self) -> bool: # if time_diff.total_seconds() > self._cooldown: # return True # return False - return self._worker is None + return self._worker is None \ No newline at end of file From 365e7dc5e774ae00b2a447afcab3098fb38fecb9 Mon Sep 17 00:00:00 2001 From: Alyssa Cote Date: Tue, 16 Jul 2024 12:19:06 -0500 Subject: [PATCH 45/57] StatusEnum -> Status --- .../_core/mli/infrastructure/control/workermanager.py | 6 +++--- smartsim/_core/mli/infrastructure/worker/worker.py | 4 ++-- smartsim/_core/mli/message_handler.py | 4 ++-- smartsim/_core/mli/mli_schemas/response/response.capnp | 4 ++-- .../_core/mli/mli_schemas/response/response_capnp.pyi | 4 ++-- tests/dragon/test_reply_building.py | 10 +++++----- 6 files changed, 16 insertions(+), 16 deletions(-) diff --git a/smartsim/_core/mli/infrastructure/control/workermanager.py b/smartsim/_core/mli/infrastructure/control/workermanager.py index 2ffe089418..5c61fcc0f8 100644 --- a/smartsim/_core/mli/infrastructure/control/workermanager.py +++ b/smartsim/_core/mli/infrastructure/control/workermanager.py @@ -57,7 +57,7 @@ from dragon.fli import FLInterface from smartsim._core.mli.mli_schemas.model.model_capnp import Model - from smartsim._core.mli.mli_schemas.response.response_capnp import StatusEnum + from smartsim._core.mli.mli_schemas.response.response_capnp import Status logger = get_logger(__name__) @@ -124,7 +124,7 @@ def deserialize_message( return inference_request -def build_failure_reply(status: "StatusEnum", message: str) -> Response: +def build_failure_reply(status: "Status", message: str) -> Response: return MessageHandler.build_response( status=status, # todo: need to indicate correct status message=message, # todo: decide what these will be @@ -450,4 +450,4 @@ def _can_shutdown(self) -> bool: # if time_diff.total_seconds() > self._cooldown: # return True # return False - return self._worker is None \ No newline at end of file + return self._worker is None diff --git a/smartsim/_core/mli/infrastructure/worker/worker.py b/smartsim/_core/mli/infrastructure/worker/worker.py 
index 3d3b36fbdf..dd874abe39 100644 --- a/smartsim/_core/mli/infrastructure/worker/worker.py +++ b/smartsim/_core/mli/infrastructure/worker/worker.py @@ -34,7 +34,7 @@ from ...mli_schemas.model.model_capnp import Model if t.TYPE_CHECKING: - from smartsim._core.mli.mli_schemas.response.response_capnp import StatusEnum + from smartsim._core.mli.mli_schemas.response.response_capnp import Status logger = get_logger(__name__) @@ -73,7 +73,7 @@ def __init__( self, outputs: t.Optional[t.Collection[t.Any]] = None, output_keys: t.Optional[t.Collection[str]] = None, - status_enum: "StatusEnum" = "running", + status_enum: "Status" = "running", message: str = "In progress", ) -> None: """Initialize the object""" diff --git a/smartsim/_core/mli/message_handler.py b/smartsim/_core/mli/message_handler.py index bcf1cfdf14..6fd66e6db7 100644 --- a/smartsim/_core/mli/message_handler.py +++ b/smartsim/_core/mli/message_handler.py @@ -405,7 +405,7 @@ def deserialize_request(request_bytes: bytes) -> request_capnp.Request: @staticmethod def _assign_status( - response: response_capnp.Response, status: "response_capnp.StatusEnum" + response: response_capnp.Response, status: "response_capnp.Status" ) -> None: """ Assigns a status to the supplied response. @@ -498,7 +498,7 @@ def _assign_custom_response_attributes( @staticmethod def build_response( - status: "response_capnp.StatusEnum", + status: "response_capnp.Status", message: str, result: t.Union[ t.List[tensor_capnp.Tensor], t.List[data_references_capnp.TensorKey] diff --git a/smartsim/_core/mli/mli_schemas/response/response.capnp b/smartsim/_core/mli/mli_schemas/response/response.capnp index 61df9d7104..83aa05a41b 100644 --- a/smartsim/_core/mli/mli_schemas/response/response.capnp +++ b/smartsim/_core/mli/mli_schemas/response/response.capnp @@ -30,7 +30,7 @@ using Tensors = import "../tensor/tensor.capnp"; using ResponseAttributes = import "response_attributes/response_attributes.capnp"; using DataRef = import "../data/data_references.capnp"; -enum StatusEnum { +enum Status { complete @0; fail @1; timeout @2; @@ -38,7 +38,7 @@ enum StatusEnum { } struct Response { - status @0 :StatusEnum; + status @0 :Status; message @1 :Text; result :union { keys @2 :List(DataRef.TensorKey); diff --git a/smartsim/_core/mli/mli_schemas/response/response_capnp.pyi b/smartsim/_core/mli/mli_schemas/response/response_capnp.pyi index e20f3b79ee..f19bdefe04 100644 --- a/smartsim/_core/mli/mli_schemas/response/response_capnp.pyi +++ b/smartsim/_core/mli/mli_schemas/response/response_capnp.pyi @@ -45,7 +45,7 @@ from .response_attributes.response_attributes_capnp import ( TorchResponseAttributesReader, ) -StatusEnum = Literal["complete", "fail", "timeout", "running"] +Status = Literal["complete", "fail", "timeout", "running"] class Response: class Result: @@ -150,7 +150,7 @@ class Response: def write(file: BufferedWriter) -> None: ... @staticmethod def write_packed(file: BufferedWriter) -> None: ... 
- status: StatusEnum + status: Status message: str result: Response.Result | Response.ResultBuilder | Response.ResultReader customAttributes: ( diff --git a/tests/dragon/test_reply_building.py b/tests/dragon/test_reply_building.py index f585096788..d1c4d226bb 100644 --- a/tests/dragon/test_reply_building.py +++ b/tests/dragon/test_reply_building.py @@ -37,7 +37,7 @@ from smartsim._core.mli.infrastructure.worker.worker import InferenceReply if t.TYPE_CHECKING: - from smartsim._core.mli.mli_schemas.response.response_capnp import StatusEnum + from smartsim._core.mli.mli_schemas.response.response_capnp import Status # The tests in this file belong to the dragon group pytestmark = pytest.mark.dragon @@ -50,7 +50,7 @@ pytest.param("fail", "Failed while executing", id="fail"), ], ) -def test_build_failure_reply(status: "StatusEnum", message: str): +def test_build_failure_reply(status: "Status", message: str): "Ensures failure replies can be built successfully" response = build_failure_reply(status, message) assert response.status == status @@ -58,7 +58,7 @@ def test_build_failure_reply(status: "StatusEnum", message: str): def test_build_failure_reply_fails(): - "Ensures ValueError is raised if a StatusEnum is not used" + "Ensures ValueError is raised if a Status Enum is not used" with pytest.raises(ValueError) as ex: response = build_failure_reply("not a status enum", "message") @@ -71,7 +71,7 @@ def test_build_failure_reply_fails(): pytest.param("complete", "Success", id="complete"), ], ) -def test_build_reply(status: "StatusEnum", message: str): +def test_build_reply(status: "Status", message: str): "Ensures replies can be built successfully" reply = InferenceReply() reply.status_enum = status @@ -82,7 +82,7 @@ def test_build_reply(status: "StatusEnum", message: str): def test_build_reply_fails(): - "Ensures ValueError is raised if a StatusEnum is not used" + "Ensures ValueError is raised if a Status Enum is not used" with pytest.raises(ValueError) as ex: reply = InferenceReply() reply.status_enum = "not a status enum" From 1952561e682e9156b36148a548666aad8c897dc6 Mon Sep 17 00:00:00 2001 From: Alyssa Cote Date: Tue, 16 Jul 2024 12:39:50 -0500 Subject: [PATCH 46/57] more fixes --- tests/dragon/test_error_handling.py | 5 ++--- tests/dragon/utils/channel.py | 4 ++++ 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/tests/dragon/test_error_handling.py b/tests/dragon/test_error_handling.py index 4ced3e0ff2..2e983c9220 100644 --- a/tests/dragon/test_error_handling.py +++ b/tests/dragon/test_error_handling.py @@ -89,8 +89,7 @@ def setup_worker_manager(test_dir, monkeypatch: pytest.MonkeyPatch): test_dir, model, [tensor_key], [tensor_key], [], None ) ser_request = MessageHandler.serialize_request(request) - new_sender = worker_manager._task_queue.sendh(use_main_as_stream_channel=True) - new_sender.send_bytes(ser_request) + worker_manager._task_queue.send(ser_request) return worker_manager, integrated_worker @@ -163,7 +162,7 @@ def test_pipeline_stage_errors_handled( monkeypatch.setattr( integrated_worker, "fetch_inputs", - MagicMock(return_value=FetchInputResult([b"result_bytes"])), + MagicMock(return_value=FetchInputResult([b"result_bytes"], None)), ) if stage not in ["fetch_model", "load_model", "fetch_inputs", "transform_input"]: monkeypatch.setattr( diff --git a/tests/dragon/utils/channel.py b/tests/dragon/utils/channel.py index 4bc2014ea3..a559f310ae 100644 --- a/tests/dragon/utils/channel.py +++ b/tests/dragon/utils/channel.py @@ -57,3 +57,7 @@ def send(self, value: bytes) -> None: 
f"Channel {self.descriptor.decode('utf-8')} sending message to {self._file_path}" ) self._file_path.write_bytes(value) + + def recv(self) -> bytes: + """Receieve a message through the underlying communication channel + :returns: the received message""" From 48035c73548d96cb4d14d24c25f75957e2bcd1e5 Mon Sep 17 00:00:00 2001 From: Alyssa Cote Date: Tue, 16 Jul 2024 12:41:35 -0500 Subject: [PATCH 47/57] another fix --- tests/dragon/utils/channel.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/dragon/utils/channel.py b/tests/dragon/utils/channel.py index a559f310ae..df76c484b5 100644 --- a/tests/dragon/utils/channel.py +++ b/tests/dragon/utils/channel.py @@ -61,3 +61,4 @@ def send(self, value: bytes) -> None: def recv(self) -> bytes: """Receieve a message through the underlying communication channel :returns: the received message""" + ... From e07a2c8f745e986f149a62220784274831e7cd61 Mon Sep 17 00:00:00 2001 From: Alyssa Cote Date: Tue, 16 Jul 2024 15:01:26 -0500 Subject: [PATCH 48/57] pr comments --- .../infrastructure/control/workermanager.py | 29 +++++++++---------- tests/dragon/test_error_handling.py | 18 ++++++++---- 2 files changed, 25 insertions(+), 22 deletions(-) diff --git a/smartsim/_core/mli/infrastructure/control/workermanager.py b/smartsim/_core/mli/infrastructure/control/workermanager.py index 5c61fcc0f8..4bfcfd1e30 100644 --- a/smartsim/_core/mli/infrastructure/control/workermanager.py +++ b/smartsim/_core/mli/infrastructure/control/workermanager.py @@ -173,8 +173,7 @@ def build_reply(reply: InferenceReply) -> Response: def exception_handler( exc: Exception, reply_channel: t.Optional[CommChannelBase], - func_descriptor: str, - reply: InferenceReply, + failure_message: str ) -> None: """ Logs exceptions, sets reply attributes, and sends the @@ -186,13 +185,11 @@ def exception_handler( :param reply: InferenceReply to modify """ logger.exception( - f"An error occurred while {func_descriptor}.\n" + f"{failure_message}.\n" f"Exception type: {type(exc).__name__}.\n" f"Exception message: {str(exc)}" ) - reply.status_enum = "fail" - reply.message = f"Failed while {func_descriptor}." - response = build_failure_reply(reply.status_enum, reply.message) + response = build_failure_reply('fail', failure_message) serialized_resp = MessageHandler.serialize_response(response) # type: ignore if reply_channel: reply_channel.send(serialized_resp) @@ -317,7 +314,7 @@ def _on_iteration(self) -> None: time.sleep(0.1) except Exception as e: exception_handler( - e, request.callback, "fetching the model", reply + e, request.callback, "Failed while fetching the model." ) break return @@ -342,7 +339,7 @@ def _on_iteration(self) -> None: self._cached_models[request.model_key] = model_result.model except Exception as e: exception_handler( - e, request.callback, "loading the model", reply + e, request.callback, "Failed while loading the model." 
) return @@ -352,14 +349,14 @@ def _on_iteration(self) -> None: request, self._feature_store ) except Exception as e: - exception_handler(e, request.callback, "fetching the model", reply) + exception_handler(e, request.callback, "Failed while fetching the model.") return try: model_result = self._worker.load_model( request, fetch_result=fetch_model_result, device=self._device ) except Exception as e: - exception_handler(e, request.callback, "loading the model", reply) + exception_handler(e, request.callback, "Failed while loading the model.") return timings.append(time.perf_counter() - interm) # timing @@ -367,7 +364,7 @@ def _on_iteration(self) -> None: try: fetch_input_result = self._worker.fetch_inputs(request, self._feature_store) except Exception as e: - exception_handler(e, request.callback, "fetching the inputs", reply) + exception_handler(e, request.callback, "Failed while fetching the inputs.") return timings.append(time.perf_counter() - interm) # timing @@ -377,7 +374,7 @@ def _on_iteration(self) -> None: request, fetch_input_result, self._device ) except Exception as e: - exception_handler(e, request.callback, "transforming the input", reply) + exception_handler(e, request.callback, "Failed while transforming the input.") return timings.append(time.perf_counter() - interm) # timing @@ -388,7 +385,7 @@ def _on_iteration(self) -> None: request, model_result, transformed_input ) except Exception as e: - exception_handler(e, request.callback, "executing", reply) + exception_handler(e, request.callback, "Failed while executing.") return timings.append(time.perf_counter() - interm) # timing @@ -398,7 +395,7 @@ def _on_iteration(self) -> None: request, execute_result, self._device ) except Exception as e: - exception_handler(e, request.callback, "transforming the output", reply) + exception_handler(e, request.callback, "Failed while transforming the output.") return timings.append(time.perf_counter() - interm) # timing @@ -411,7 +408,7 @@ def _on_iteration(self) -> None: self._feature_store, ) except Exception as e: - exception_handler(e, request.callback, "placing the output", reply) + exception_handler(e, request.callback, "Failed while placing the output.") return else: reply.outputs = transformed_output.outputs @@ -420,7 +417,7 @@ def _on_iteration(self) -> None: interm = time.perf_counter() # timing if reply.outputs is None or not reply.outputs: - response = build_failure_reply("fail", "no-results") + response = build_failure_reply("fail", "Outputs not found.") else: reply.status_enum = "complete" reply.message = "Success" diff --git a/tests/dragon/test_error_handling.py b/tests/dragon/test_error_handling.py index 2e983c9220..4b436d62e0 100644 --- a/tests/dragon/test_error_handling.py +++ b/tests/dragon/test_error_handling.py @@ -105,8 +105,8 @@ def mock_stage(*args, **kwargs): mock_reply_fn, ) - def mock_exception_handler(exc, reply_channel, func_descriptor, reply): - return exception_handler(exc, None, func_descriptor, reply) + def mock_exception_handler(exc, reply_channel, failure_message): + return exception_handler(exc, None, failure_message) monkeypatch.setattr( "smartsim._core.mli.infrastructure.control.workermanager.exception_handler", @@ -204,13 +204,19 @@ def test_pipeline_stage_errors_handled( mock_reply_fn.assert_called_with("fail", error_message) -def test_exception_handling_helper(): +def test_exception_handling_helper(monkeypatch: pytest.MonkeyPatch): """Ensures that the worker manager does not crash after a failure in the execute pipeline stage""" reply = 
InferenceReply() + mock_reply_fn = MagicMock() + monkeypatch.setattr( + "smartsim._core.mli.infrastructure.control.workermanager.build_failure_reply", + mock_reply_fn, + ) + test_exception = ValueError("Test ValueError") - exception_handler(test_exception, None, "fetching the model", reply) + exception_handler(test_exception, None, "Failure while fetching the model.") - assert reply.status_enum == "fail" - assert reply.message == "Failed while fetching the model." + assert mock_reply_fn.called_once() + mock_reply_fn.assert_called_with("fail", "Failure while fetching the model.") From 042e56d66a52341ca6c03042f68953627a163f18 Mon Sep 17 00:00:00 2001 From: Alyssa Cote Date: Tue, 16 Jul 2024 15:09:02 -0500 Subject: [PATCH 49/57] remove extra punctuation --- smartsim/_core/mli/infrastructure/control/workermanager.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/smartsim/_core/mli/infrastructure/control/workermanager.py b/smartsim/_core/mli/infrastructure/control/workermanager.py index 4bfcfd1e30..86f6759b9f 100644 --- a/smartsim/_core/mli/infrastructure/control/workermanager.py +++ b/smartsim/_core/mli/infrastructure/control/workermanager.py @@ -186,7 +186,7 @@ def exception_handler( """ logger.exception( f"{failure_message}.\n" - f"Exception type: {type(exc).__name__}.\n" + f"Exception type: {type(exc).__name__}\n" f"Exception message: {str(exc)}" ) response = build_failure_reply('fail', failure_message) From ec322351eefc495ef15c9ac67217d73e918cb6ac Mon Sep 17 00:00:00 2001 From: Alyssa Cote Date: Tue, 16 Jul 2024 15:36:10 -0500 Subject: [PATCH 50/57] style --- .../infrastructure/control/workermanager.py | 26 ++++++++++++------- 1 file changed, 17 insertions(+), 9 deletions(-) diff --git a/smartsim/_core/mli/infrastructure/control/workermanager.py b/smartsim/_core/mli/infrastructure/control/workermanager.py index 86f6759b9f..9d55ad0600 100644 --- a/smartsim/_core/mli/infrastructure/control/workermanager.py +++ b/smartsim/_core/mli/infrastructure/control/workermanager.py @@ -171,9 +171,7 @@ def build_reply(reply: InferenceReply) -> Response: def exception_handler( - exc: Exception, - reply_channel: t.Optional[CommChannelBase], - failure_message: str + exc: Exception, reply_channel: t.Optional[CommChannelBase], failure_message: str ) -> None: """ Logs exceptions, sets reply attributes, and sends the @@ -189,7 +187,7 @@ def exception_handler( f"Exception type: {type(exc).__name__}\n" f"Exception message: {str(exc)}" ) - response = build_failure_reply('fail', failure_message) + response = build_failure_reply("fail", failure_message) serialized_resp = MessageHandler.serialize_response(response) # type: ignore if reply_channel: reply_channel.send(serialized_resp) @@ -349,14 +347,18 @@ def _on_iteration(self) -> None: request, self._feature_store ) except Exception as e: - exception_handler(e, request.callback, "Failed while fetching the model.") + exception_handler( + e, request.callback, "Failed while fetching the model." + ) return try: model_result = self._worker.load_model( request, fetch_result=fetch_model_result, device=self._device ) except Exception as e: - exception_handler(e, request.callback, "Failed while loading the model.") + exception_handler( + e, request.callback, "Failed while loading the model." 
+ ) return timings.append(time.perf_counter() - interm) # timing @@ -374,7 +376,9 @@ def _on_iteration(self) -> None: request, fetch_input_result, self._device ) except Exception as e: - exception_handler(e, request.callback, "Failed while transforming the input.") + exception_handler( + e, request.callback, "Failed while transforming the input." + ) return timings.append(time.perf_counter() - interm) # timing @@ -395,7 +399,9 @@ def _on_iteration(self) -> None: request, execute_result, self._device ) except Exception as e: - exception_handler(e, request.callback, "Failed while transforming the output.") + exception_handler( + e, request.callback, "Failed while transforming the output." + ) return timings.append(time.perf_counter() - interm) # timing @@ -408,7 +414,9 @@ def _on_iteration(self) -> None: self._feature_store, ) except Exception as e: - exception_handler(e, request.callback, "Failed while placing the output.") + exception_handler( + e, request.callback, "Failed while placing the output." + ) return else: reply.outputs = transformed_output.outputs From 7e53d08cdc9b2179c440eda95bcb2683d018f42c Mon Sep 17 00:00:00 2001 From: Alyssa Cote Date: Tue, 16 Jul 2024 16:12:10 -0500 Subject: [PATCH 51/57] more punctuation --- smartsim/_core/mli/infrastructure/control/workermanager.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/smartsim/_core/mli/infrastructure/control/workermanager.py b/smartsim/_core/mli/infrastructure/control/workermanager.py index 9d55ad0600..f372ae513f 100644 --- a/smartsim/_core/mli/infrastructure/control/workermanager.py +++ b/smartsim/_core/mli/infrastructure/control/workermanager.py @@ -183,7 +183,7 @@ def exception_handler( :param reply: InferenceReply to modify """ logger.exception( - f"{failure_message}.\n" + f"{failure_message}\n" f"Exception type: {type(exc).__name__}\n" f"Exception message: {str(exc)}" ) From d986164cb526776b7206f80ddb2bc3f7a13867e6 Mon Sep 17 00:00:00 2001 From: Alyssa Cote Date: Tue, 16 Jul 2024 19:36:01 -0500 Subject: [PATCH 52/57] send_failure and tests --- .../infrastructure/control/workermanager.py | 82 +++++++++---------- tests/dragon/test_error_handling.py | 67 +++++++++++++-- 2 files changed, 95 insertions(+), 54 deletions(-) diff --git a/smartsim/_core/mli/infrastructure/control/workermanager.py b/smartsim/_core/mli/infrastructure/control/workermanager.py index f372ae513f..2c9dfd592e 100644 --- a/smartsim/_core/mli/infrastructure/control/workermanager.py +++ b/smartsim/_core/mli/infrastructure/control/workermanager.py @@ -174,21 +174,30 @@ def exception_handler( exc: Exception, reply_channel: t.Optional[CommChannelBase], failure_message: str ) -> None: """ - Logs exceptions, sets reply attributes, and sends the - failure response without taking down the WorkerManager. + Logs exceptions, calls send_failure to send the failed response back. 
:param exc: The exception to be logged :param reply_channel: The channel used to send replies - :param func_descriptor: Descriptor to help form error messages - :param reply: InferenceReply to modify + :param failure_message: Failure message to log and send back """ logger.exception( f"{failure_message}\n" f"Exception type: {type(exc).__name__}\n" f"Exception message: {str(exc)}" ) - response = build_failure_reply("fail", failure_message) - serialized_resp = MessageHandler.serialize_response(response) # type: ignore + send_failure(reply_channel, failure_message) + + +def send_failure( + reply_channel: t.Optional[CommChannelBase], failure_message: str +) -> None: + """ + Sends back the failed response. + + :param reply_channel: The channel used to send replies + :param failure_message: Failure message for response + """ + serialized_resp = MessageHandler.serialize_response(build_failure_reply("fail", failure_message)) # type: ignore if reply_channel: reply_channel.send(serialized_resp) @@ -290,10 +299,7 @@ def _on_iteration(self) -> None: if not request.raw_model: if request.model_key is None: - response = build_failure_reply("fail", "Could not read model key.") - serialized_resp = MessageHandler.serialize_response(response) # type: ignore - if request.callback: - request.callback.send(serialized_resp) + send_failure(request.callback, "Could not find model key or model.") return if request.model_key in self._cached_models: timings.append(time.perf_counter() - interm) # timing @@ -301,45 +307,31 @@ def _on_iteration(self) -> None: model_result = LoadModelResult(self._cached_models[request.model_key]) else: - fetch_model_result = None - while True: - try: - interm = time.perf_counter() # timing - fetch_model_result = self._worker.fetch_model( - request, self._feature_store - ) - except KeyError: - time.sleep(0.1) - except Exception as e: - exception_handler( - e, request.callback, "Failed while fetching the model." - ) - break + try: + interm = time.perf_counter() # timing + fetch_model_result = self._worker.fetch_model( + request, self._feature_store + ) + except Exception as e: + exception_handler( + e, request.callback, "Failed while fetching the model." + ) return - if fetch_model_result is None: - response = build_failure_reply( - "fail", "Could not retrieve model from feature store." + timings.append(time.perf_counter() - interm) # timing + interm = time.perf_counter() # timing + try: + model_result = self._worker.load_model( + request, + fetch_result=fetch_model_result, + device=self._device, + ) + self._cached_models[request.model_key] = model_result.model + except Exception as e: + exception_handler( + e, request.callback, "Failed while loading the model." ) - serialized_resp = MessageHandler.serialize_response(response) # type: ignore - if request.callback: - request.callback.send(serialized_resp) return - else: - timings.append(time.perf_counter() - interm) # timing - interm = time.perf_counter() # timing - try: - model_result = self._worker.load_model( - request, - fetch_result=fetch_model_result, - device=self._device, - ) - self._cached_models[request.model_key] = model_result.model - except Exception as e: - exception_handler( - e, request.callback, "Failed while loading the model." 
- ) - return else: try: diff --git a/tests/dragon/test_error_handling.py b/tests/dragon/test_error_handling.py index 4b436d62e0..9df914e044 100644 --- a/tests/dragon/test_error_handling.py +++ b/tests/dragon/test_error_handling.py @@ -40,6 +40,7 @@ from smartsim._core.mli.infrastructure.control.workermanager import ( WorkerManager, exception_handler, + send_failure, ) from smartsim._core.mli.infrastructure.environmentloader import EnvironmentConfigLoader from smartsim._core.mli.infrastructure.storage.dragonfeaturestore import ( @@ -48,6 +49,7 @@ from smartsim._core.mli.infrastructure.worker.worker import ( ExecuteResult, FetchInputResult, + FetchModelResult, InferenceReply, LoadModelResult, TransformInputResult, @@ -63,7 +65,7 @@ @pytest.fixture -def setup_worker_manager(test_dir, monkeypatch: pytest.MonkeyPatch): +def setup_worker_manager_model_bytes(test_dir, monkeypatch: pytest.MonkeyPatch): integrated_worker = IntegratedTorchWorker() chan = Channel.make_process_local() @@ -94,6 +96,38 @@ def setup_worker_manager(test_dir, monkeypatch: pytest.MonkeyPatch): return worker_manager, integrated_worker +@pytest.fixture +def setup_worker_manager_model_key(test_dir, monkeypatch: pytest.MonkeyPatch): + integrated_worker = IntegratedTorchWorker() + + chan = Channel.make_process_local() + queue = FLInterface(main_ch=chan) + monkeypatch.setenv("SSQueue", du.B64.bytes_to_str(queue.serialize())) + storage = DDict() + feature_store = DragonFeatureStore(storage) + monkeypatch.setenv( + "SSFeatureStore", base64.b64encode(pickle.dumps(feature_store)).decode("utf-8") + ) + + worker_manager = WorkerManager( + EnvironmentConfigLoader(), + integrated_worker, + as_service=False, + cooldown=3, + comm_channel_type=FileSystemCommChannel, + ) + + tensor_key = MessageHandler.build_tensor_key("key") + model_key = MessageHandler.build_model_key("model key") + request = MessageHandler.build_request( + test_dir, model_key, [tensor_key], [tensor_key], [], None + ) + ser_request = MessageHandler.serialize_request(request) + worker_manager._task_queue.send(ser_request) + + return worker_manager, integrated_worker + + def mock_pipeline_stage(monkeypatch: pytest.MonkeyPatch, integrated_worker, stage): def mock_stage(*args, **kwargs): raise ValueError(f"Simulated error in {stage}") @@ -105,16 +139,24 @@ def mock_stage(*args, **kwargs): mock_reply_fn, ) - def mock_exception_handler(exc, reply_channel, failure_message): - return exception_handler(exc, None, failure_message) + def mock_send_failure(reply_channel, failure_message): + return send_failure(None, failure_message) monkeypatch.setattr( - "smartsim._core.mli.infrastructure.control.workermanager.exception_handler", - mock_exception_handler, + "smartsim._core.mli.infrastructure.control.workermanager.send_failure", + mock_send_failure, ) + return mock_reply_fn +@pytest.mark.parametrize( + "setup_worker_manager", + [ + pytest.param("setup_worker_manager_model_bytes"), + pytest.param("setup_worker_manager_model_key"), + ], +) @pytest.mark.parametrize( "stage, error_message", [ @@ -142,16 +184,23 @@ def mock_exception_handler(exc, reply_channel, failure_message): ], ) def test_pipeline_stage_errors_handled( + request, setup_worker_manager, monkeypatch: pytest.MonkeyPatch, stage: str, error_message: str, ): """Ensures that the worker manager does not crash after a failure in various pipeline stages""" - worker_manager, integrated_worker = setup_worker_manager - + worker_manager, integrated_worker = request.getfixturevalue(setup_worker_manager) mock_reply_fn = 
mock_pipeline_stage(monkeypatch, integrated_worker, stage) + if stage not in ["fetch_model"]: + monkeypatch.setattr( + integrated_worker, + "fetch_model", + MagicMock(return_value=FetchModelResult(b"result_bytes")), + ) + if stage not in ["fetch_model", "load_model"]: monkeypatch.setattr( integrated_worker, @@ -200,7 +249,7 @@ def test_pipeline_stage_errors_handled( worker_manager._on_iteration() - assert mock_reply_fn.called_once() + mock_reply_fn.assert_called_once() mock_reply_fn.assert_called_with("fail", error_message) @@ -218,5 +267,5 @@ def test_exception_handling_helper(monkeypatch: pytest.MonkeyPatch): test_exception = ValueError("Test ValueError") exception_handler(test_exception, None, "Failure while fetching the model.") - assert mock_reply_fn.called_once() + mock_reply_fn.assert_called_once() mock_reply_fn.assert_called_with("fail", "Failure while fetching the model.") From d9e22fb739e8a41e5c73eeb8a2cf6ad4d0f468d7 Mon Sep 17 00:00:00 2001 From: Alyssa Cote Date: Wed, 17 Jul 2024 11:17:36 -0500 Subject: [PATCH 53/57] reintroduce questionable while loop --- .../infrastructure/control/workermanager.py | 26 ++++++++++++------- 1 file changed, 16 insertions(+), 10 deletions(-) diff --git a/smartsim/_core/mli/infrastructure/control/workermanager.py b/smartsim/_core/mli/infrastructure/control/workermanager.py index 2c9dfd592e..f3b7b58fda 100644 --- a/smartsim/_core/mli/infrastructure/control/workermanager.py +++ b/smartsim/_core/mli/infrastructure/control/workermanager.py @@ -307,16 +307,22 @@ def _on_iteration(self) -> None: model_result = LoadModelResult(self._cached_models[request.model_key]) else: - try: - interm = time.perf_counter() # timing - fetch_model_result = self._worker.fetch_model( - request, self._feature_store - ) - except Exception as e: - exception_handler( - e, request.callback, "Failed while fetching the model." - ) - return + fetch_model_result = None + while fetch_model_result is None: + try: + interm = time.perf_counter() # timing + fetch_model_result = self._worker.fetch_model( + request, self._feature_store + ) + + # do we want to keep this? it could cause an infinite loop + except KeyError: + time.sleep(0.01) + except Exception as e: + exception_handler( + e, request.callback, "Failed while fetching the model." + ) + return timings.append(time.perf_counter() - interm) # timing interm = time.perf_counter() # timing From 3caf24beb76a070076a65d009677feac32e4c331 Mon Sep 17 00:00:00 2001 From: Alyssa Cote Date: Wed, 17 Jul 2024 11:49:27 -0500 Subject: [PATCH 54/57] remove send_failure --- .../infrastructure/control/workermanager.py | 20 ++++++------------- tests/dragon/test_error_handling.py | 9 ++++----- 2 files changed, 10 insertions(+), 19 deletions(-) diff --git a/smartsim/_core/mli/infrastructure/control/workermanager.py b/smartsim/_core/mli/infrastructure/control/workermanager.py index f3b7b58fda..a5d3bbd579 100644 --- a/smartsim/_core/mli/infrastructure/control/workermanager.py +++ b/smartsim/_core/mli/infrastructure/control/workermanager.py @@ -174,7 +174,7 @@ def exception_handler( exc: Exception, reply_channel: t.Optional[CommChannelBase], failure_message: str ) -> None: """ - Logs exceptions, calls send_failure to send the failed response back. + Logs exceptions and sends a failure response. 
:param exc: The exception to be logged :param reply_channel: The channel used to send replies @@ -185,18 +185,6 @@ def exception_handler( f"Exception type: {type(exc).__name__}\n" f"Exception message: {str(exc)}" ) - send_failure(reply_channel, failure_message) - - -def send_failure( - reply_channel: t.Optional[CommChannelBase], failure_message: str -) -> None: - """ - Sends back the failed response. - - :param reply_channel: The channel used to send replies - :param failure_message: Failure message for response - """ serialized_resp = MessageHandler.serialize_response(build_failure_reply("fail", failure_message)) # type: ignore if reply_channel: reply_channel.send(serialized_resp) @@ -299,7 +287,11 @@ def _on_iteration(self) -> None: if not request.raw_model: if request.model_key is None: - send_failure(request.callback, "Could not find model key or model.") + exception_handler( + ValueError("Could not find model key or model"), + request.callback, + "Could not find model key or model.", + ) return if request.model_key in self._cached_models: timings.append(time.perf_counter() - interm) # timing diff --git a/tests/dragon/test_error_handling.py b/tests/dragon/test_error_handling.py index 9df914e044..151bdd2fcc 100644 --- a/tests/dragon/test_error_handling.py +++ b/tests/dragon/test_error_handling.py @@ -40,7 +40,6 @@ from smartsim._core.mli.infrastructure.control.workermanager import ( WorkerManager, exception_handler, - send_failure, ) from smartsim._core.mli.infrastructure.environmentloader import EnvironmentConfigLoader from smartsim._core.mli.infrastructure.storage.dragonfeaturestore import ( @@ -139,12 +138,12 @@ def mock_stage(*args, **kwargs): mock_reply_fn, ) - def mock_send_failure(reply_channel, failure_message): - return send_failure(None, failure_message) + def mock_exception_handler(exc, reply_channel, failure_message): + return exception_handler(exc, None, failure_message) monkeypatch.setattr( - "smartsim._core.mli.infrastructure.control.workermanager.send_failure", - mock_send_failure, + "smartsim._core.mli.infrastructure.control.workermanager.exception_handler", + mock_exception_handler, ) return mock_reply_fn From be1ddea5688dbd132e12d60f78b74cf3a12aa4f4 Mon Sep 17 00:00:00 2001 From: Alyssa Cote Date: Wed, 17 Jul 2024 12:06:01 -0500 Subject: [PATCH 55/57] adjust and add timing --- .../_core/mli/infrastructure/control/workermanager.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/smartsim/_core/mli/infrastructure/control/workermanager.py b/smartsim/_core/mli/infrastructure/control/workermanager.py index a5d3bbd579..6569ca5f08 100644 --- a/smartsim/_core/mli/infrastructure/control/workermanager.py +++ b/smartsim/_core/mli/infrastructure/control/workermanager.py @@ -301,8 +301,9 @@ def _on_iteration(self) -> None: else: fetch_model_result = None while fetch_model_result is None: + timings.append(time.perf_counter() - interm) # timing + interm = time.perf_counter() # timing try: - interm = time.perf_counter() # timing fetch_model_result = self._worker.fetch_model( request, self._feature_store ) @@ -332,6 +333,8 @@ def _on_iteration(self) -> None: return else: + timings.append(time.perf_counter() - interm) # timing + interm = time.perf_counter() # timing try: fetch_model_result = self._worker.fetch_model( request, self._feature_store @@ -341,6 +344,9 @@ def _on_iteration(self) -> None: e, request.callback, "Failed while fetching the model." 
) return + + timings.append(time.perf_counter() - interm) # timing + interm = time.perf_counter() # timing try: model_result = self._worker.load_model( request, fetch_result=fetch_model_result, device=self._device @@ -373,7 +379,6 @@ def _on_iteration(self) -> None: timings.append(time.perf_counter() - interm) # timing interm = time.perf_counter() # timing - try: execute_result = self._worker.execute( request, model_result, transformed_input From b3927c711f95170b0c7a223c8efdc66d0bc21d28 Mon Sep 17 00:00:00 2001 From: Alyssa Cote Date: Wed, 17 Jul 2024 13:55:30 -0500 Subject: [PATCH 56/57] remove while loop --- .../infrastructure/control/workermanager.py | 28 ++++++++----------- 1 file changed, 11 insertions(+), 17 deletions(-) diff --git a/smartsim/_core/mli/infrastructure/control/workermanager.py b/smartsim/_core/mli/infrastructure/control/workermanager.py index 6569ca5f08..d22e25ff31 100644 --- a/smartsim/_core/mli/infrastructure/control/workermanager.py +++ b/smartsim/_core/mli/infrastructure/control/workermanager.py @@ -299,23 +299,17 @@ def _on_iteration(self) -> None: model_result = LoadModelResult(self._cached_models[request.model_key]) else: - fetch_model_result = None - while fetch_model_result is None: - timings.append(time.perf_counter() - interm) # timing - interm = time.perf_counter() # timing - try: - fetch_model_result = self._worker.fetch_model( - request, self._feature_store - ) - - # do we want to keep this? it could cause an infinite loop - except KeyError: - time.sleep(0.01) - except Exception as e: - exception_handler( - e, request.callback, "Failed while fetching the model." - ) - return + timings.append(time.perf_counter() - interm) # timing + interm = time.perf_counter() # timing + try: + fetch_model_result = self._worker.fetch_model( + request, self._feature_store + ) + except Exception as e: + exception_handler( + e, request.callback, "Failed while fetching the model." 
+ ) + return timings.append(time.perf_counter() - interm) # timing interm = time.perf_counter() # timing From 7f104b06d9af1e409a5d3ccea71b109a3b2797c8 Mon Sep 17 00:00:00 2001 From: Alyssa Cote Date: Wed, 17 Jul 2024 15:23:22 -0500 Subject: [PATCH 57/57] remove comments, remove type ignore from workermanager --- .../mli/infrastructure/control/workermanager.py | 16 +++++++++------- smartsim/_core/mli/message_handler.py | 4 ++-- 2 files changed, 11 insertions(+), 9 deletions(-) diff --git a/smartsim/_core/mli/infrastructure/control/workermanager.py b/smartsim/_core/mli/infrastructure/control/workermanager.py index d22e25ff31..8e3ed3fb4c 100644 --- a/smartsim/_core/mli/infrastructure/control/workermanager.py +++ b/smartsim/_core/mli/infrastructure/control/workermanager.py @@ -51,7 +51,7 @@ MachineLearningWorkerBase, ) from ...message_handler import MessageHandler -from ...mli_schemas.response.response_capnp import Response +from ...mli_schemas.response.response_capnp import Response, ResponseBuilder if t.TYPE_CHECKING: from dragon.fli import FLInterface @@ -124,10 +124,10 @@ def deserialize_message( return inference_request -def build_failure_reply(status: "Status", message: str) -> Response: +def build_failure_reply(status: "Status", message: str) -> ResponseBuilder: return MessageHandler.build_response( - status=status, # todo: need to indicate correct status - message=message, # todo: decide what these will be + status=status, + message=message, result=[], custom_attributes=None, ) @@ -159,7 +159,7 @@ def prepare_outputs(reply: InferenceReply) -> t.List[t.Any]: return prepared_outputs -def build_reply(reply: InferenceReply) -> Response: +def build_reply(reply: InferenceReply) -> ResponseBuilder: results = prepare_outputs(reply) return MessageHandler.build_response( @@ -185,7 +185,9 @@ def exception_handler( f"Exception type: {type(exc).__name__}\n" f"Exception message: {str(exc)}" ) - serialized_resp = MessageHandler.serialize_response(build_failure_reply("fail", failure_message)) # type: ignore + serialized_resp = MessageHandler.serialize_response( + build_failure_reply("fail", failure_message) + ) if reply_channel: reply_channel.send(serialized_resp) @@ -423,7 +425,7 @@ def _on_iteration(self) -> None: timings.append(time.perf_counter() - interm) # timing interm = time.perf_counter() # timing - serialized_resp = MessageHandler.serialize_response(response) # type: ignore + serialized_resp = MessageHandler.serialize_response(response) timings.append(time.perf_counter() - interm) # timing interm = time.perf_counter() # timing diff --git a/smartsim/_core/mli/message_handler.py b/smartsim/_core/mli/message_handler.py index 6fd66e6db7..4fe2bef3a7 100644 --- a/smartsim/_core/mli/message_handler.py +++ b/smartsim/_core/mli/message_handler.py @@ -360,7 +360,7 @@ def build_request( request_attributes_capnp.TensorFlowRequestAttributes, None, ], - ) -> request_capnp.Request: + ) -> request_capnp.RequestBuilder: """ Builds the request message. @@ -508,7 +508,7 @@ def build_response( response_attributes_capnp.TensorFlowResponseAttributes, None, ], - ) -> response_capnp.Response: + ) -> response_capnp.ResponseBuilder: """ Builds the response message.