diff --git a/doc/changelog.md b/doc/changelog.md index ac09ecf604..4ce6cf586c 100644 --- a/doc/changelog.md +++ b/doc/changelog.md @@ -13,6 +13,7 @@ Jump to: Description +- Refactor `exception_handler` - Add RequestDispatcher and the possibility of batching inference requests - Enable hostname selection for dragon tasks - Remove pydantic dependency from MLI code diff --git a/smartsim/_core/mli/infrastructure/control/error_handling.py b/smartsim/_core/mli/infrastructure/control/error_handling.py index e2c5bcd9e1..5a42a8bfa8 100644 --- a/smartsim/_core/mli/infrastructure/control/error_handling.py +++ b/smartsim/_core/mli/infrastructure/control/error_handling.py @@ -61,10 +61,10 @@ def exception_handler( f"Exception type: {type(exc).__name__}\n" f"Exception message: {str(exc)}" ) - serialized_resp = MessageHandler.serialize_response( - build_failure_reply("fail", failure_message) - ) if reply_channel: + serialized_resp = MessageHandler.serialize_response( + build_failure_reply("fail", failure_message) + ) reply_channel.send(serialized_resp) else: - logger.warning("Unable to notify client of error without reply_channel") + logger.warning("Unable to notify client of error without a reply channel") diff --git a/tests/dragon/test_error_handling.py b/tests/dragon/test_error_handling.py index b20424866a..0e737101fa 100644 --- a/tests/dragon/test_error_handling.py +++ b/tests/dragon/test_error_handling.py @@ -307,14 +307,22 @@ def mock_stage(*args, **kwargs): mock_reply_fn, ) + mock_reply_channel = MagicMock() + mock_reply_channel.send = MagicMock() + def mock_exception_handler(exc, reply_channel, failure_message): - return exception_handler(exc, None, failure_message) + return exception_handler(exc, mock_reply_channel, failure_message) monkeypatch.setattr( "smartsim._core.mli.infrastructure.control.workermanager.exception_handler", mock_exception_handler, ) + monkeypatch.setattr( + "smartsim._core.mli.infrastructure.control.requestdispatcher.exception_handler", + mock_exception_handler, + ) + return mock_reply_fn @@ -464,7 +472,9 @@ def test_dispatcher_pipeline_stage_errors_handled( def test_exception_handling_helper(monkeypatch: pytest.MonkeyPatch): """Ensures that the worker manager does not crash after a failure in the execute pipeline stage""" - reply = InferenceReply() + + mock_reply_channel = MagicMock() + mock_reply_channel.send = MagicMock() mock_reply_fn = MagicMock() monkeypatch.setattr( @@ -473,7 +483,9 @@ def test_exception_handling_helper(monkeypatch: pytest.MonkeyPatch): ) test_exception = ValueError("Test ValueError") - exception_handler(test_exception, None, "Failure while fetching the model.") + exception_handler( + test_exception, mock_reply_channel, "Failure while fetching the model." + ) mock_reply_fn.assert_called_once() mock_reply_fn.assert_called_with("fail", "Failure while fetching the model.")