Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,8 @@ private void response() {
ModelTensors tensors = processOutput(action, body, connector, scriptService, parameters, mlGuard);
tensors.setStatusCode(statusCode);
actionListener.onResponse(new Tuple<>(executionContext.getSequence(), tensors));
} catch (IllegalArgumentException e) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we also try to catch the OpensearchStatus Exception too in another catch block?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah good catch would need to look what scenarios cause a OpensearchStatus to occur

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

After checking again the try block does not do any other operations that would produce OpensearchStatusException. All the edge cases would only result in a illegalArgumentException

ModelTensors tensors = processOutput(action, body, connector, scriptService, parameters, mlGuard);

If we happen to find such error we will be able to see if through the log statement on the generic exception which would be very unlikely. The only other method that can throw a exception would be the

connector.parseResponse method which throws a IOException

connector.parseResponse(filteredResponse, modelTensors, scriptReturnModelTensor);

TLDR: adding a catch block OpensearchStatusException would be dead code, and making a test for it wouldnt be useful as we would just be mocking a scenario that doesn't exist

actionListener.onFailure(e);
} catch (Exception e) {
log.error("Failed to process response body: {}", body, e);
actionListener.onFailure(new MLException("Fail to execute " + action + " in aws connector", e));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

package org.opensearch.ml.engine.algorithms.remote;

import static org.junit.Assert.assertEquals;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
Expand All @@ -31,7 +32,6 @@
import org.opensearch.ml.common.connector.ConnectorAction;
import org.opensearch.ml.common.connector.HttpConnector;
import org.opensearch.ml.common.connector.MLPostProcessFunction;
import org.opensearch.ml.common.exception.MLException;
import org.opensearch.ml.common.output.model.ModelTensors;
import org.opensearch.script.ScriptService;
import org.reactivestreams.Publisher;
Expand Down Expand Up @@ -191,7 +191,7 @@ public void test_onError() {
ArgumentCaptor<Exception> captor = ArgumentCaptor.forClass(Exception.class);
verify(actionListener).onFailure(captor.capture());
assert captor.getValue() instanceof OpenSearchStatusException;
assert captor.getValue().getMessage().equals("Error communicating with remote model: runtime exception");
assertEquals("Error communicating with remote model: runtime exception", captor.getValue().getMessage());
}

@Test
Expand All @@ -209,7 +209,7 @@ public void test_onSubscribe() {
public void test_onNext() {
test_onSubscribe();// set the subscription to non-null.
responseSubscriber.onNext(ByteBuffer.wrap("hello world".getBytes()));
assert mlSdkAsyncHttpResponseHandler.getResponseBody().toString().equals("hello world");
assertEquals("hello world", mlSdkAsyncHttpResponseHandler.getResponseBody().toString());
}

@Test
Expand All @@ -221,7 +221,7 @@ public void test_MLResponseSubscriber_onError() {
ArgumentCaptor<Exception> captor = ArgumentCaptor.forClass(Exception.class);
verify(actionListener, times(1)).onFailure(captor.capture());
assert captor.getValue() instanceof OpenSearchStatusException;
assert captor.getValue().getMessage().equals("Remote service returned error status 500 with empty body");
assertEquals("Remote service returned error status 500 with empty body", captor.getValue().getMessage());
}

@Test
Expand Down Expand Up @@ -283,7 +283,7 @@ public void test_onComplete_failed() {
mlSdkAsyncHttpResponseHandler.onStream(stream);
ArgumentCaptor<OpenSearchStatusException> captor = ArgumentCaptor.forClass(OpenSearchStatusException.class);
verify(actionListener, times(1)).onFailure(captor.capture());
assert captor.getValue().getMessage().equals("Error from remote service: Model current status is: FAILED");
assertEquals("Error from remote service: Model current status is: FAILED", captor.getValue().getMessage());
assert captor.getValue().status().getStatus() == 500;
}

Expand All @@ -302,7 +302,7 @@ public void test_onComplete_empty_response_body() {
mlSdkAsyncHttpResponseHandler.onStream(stream);
ArgumentCaptor<OpenSearchStatusException> captor = ArgumentCaptor.forClass(OpenSearchStatusException.class);
verify(actionListener, times(1)).onFailure(captor.capture());
assert captor.getValue().getMessage().equals("Remote service returned empty response body");
assertEquals("Remote service returned empty response body", captor.getValue().getMessage());
}

@Test
Expand Down Expand Up @@ -380,14 +380,12 @@ public void test_onComplete_throttle_exception_onFailure() {

ArgumentCaptor<OpenSearchStatusException> captor = ArgumentCaptor.forClass(RemoteConnectorThrottlingException.class);
verify(actionListener, times(1)).onFailure(captor.capture());
assert captor
.getValue()
.getMessage()
.equals(
"Error from remote service: The request was denied due to remote server throttling. "
+ "To change the retry policy and behavior, please update the connector client_config."
);
assert captor.getValue().status().getStatus() == HttpStatusCode.BAD_REQUEST;
assertEquals(
"Error from remote service: The request was denied due to remote server throttling. "
+ "To change the retry policy and behavior, please update the connector client_config.",
captor.getValue().getMessage()
);
}

@Test
Expand Down Expand Up @@ -416,8 +414,39 @@ public void test_onComplete_processOutputFail_onFailure() {
};
mlSdkAsyncHttpResponseHandler.onStream(stream);

ArgumentCaptor<MLException> captor = ArgumentCaptor.forClass(MLException.class);
ArgumentCaptor<IllegalArgumentException> captor = ArgumentCaptor.forClass(IllegalArgumentException.class);
verify(actionListener, times(1)).onFailure(captor.capture());
assert captor.getValue().getMessage().equals("Fail to execute PREDICT in aws connector");
assertEquals("no PREDICT action found", captor.getValue().getMessage());
}

/**
* Asserts that IllegalArgumentException is propagated where post-processing function fails
* on response
*/
@Test
public void onComplete_InvalidEmbeddingBedRockPostProcessingOccurs_IllegalArgumentExceptionThrown() {
String invalidEmbeddingResponse = "{ \"embedding\": [[1]] }";

mlSdkAsyncHttpResponseHandler.onHeaders(sdkHttpResponse);
Publisher<ByteBuffer> stream = s -> {
try {
s.onSubscribe(mock(Subscription.class));
s.onNext(ByteBuffer.wrap(invalidEmbeddingResponse.getBytes()));
s.onComplete();
} catch (Throwable e) {
s.onError(e);
}
};
mlSdkAsyncHttpResponseHandler.onStream(stream);

ArgumentCaptor<IllegalArgumentException> exceptionCaptor = ArgumentCaptor.forClass(IllegalArgumentException.class);
verify(actionListener, times(1)).onFailure(exceptionCaptor.capture());

// Error message
assertEquals(
"BedrockEmbeddingPostProcessFunction exception message should match",
"The embedding should be a non-empty List containing Float values.",
exceptionCaptor.getValue().getMessage()
);
}
}
Loading