Skip to content

Commit dbd3f0a

Browse files
Address comments, add some UTs
Signed-off-by: Nathalie Jonathan <[email protected]>
1 parent 7f1b222 commit dbd3f0a

File tree

14 files changed

+352
-85
lines changed

14 files changed

+352
-85
lines changed

common/src/main/java/org/opensearch/ml/common/connector/HttpConnector.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -367,6 +367,10 @@ public <T> T createPayload(String action, Map<String, String> parameters) {
367367
}
368368

369369
private boolean neededStreamParameterInPayload(Map<String, String> parameters) {
370+
if (parameters == null) {
371+
return false;
372+
}
373+
370374
boolean isStream = parameters.containsKey("stream");
371375
if (!isStream) {
372376
return false;

common/src/main/java/org/opensearch/ml/common/transport/prediction/MLPredictionStreamTaskAction.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,8 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
16
package org.opensearch.ml.common.transport.prediction;
27

38
import org.opensearch.action.ActionType;

common/src/test/java/org/opensearch/ml/common/connector/HttpConnectorTest.java

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -228,6 +228,52 @@ public void createPayload_MissingParamsInvalidJson() {
228228
connector.validatePayload(predictPayload);
229229
}
230230

231+
@Test
232+
public void createPayload_WithStreamParameter_OpenAI() {
233+
String requestBody = "{\"model\": \"gpt-3.5-turbo\", \"messages\": [{\"role\": \"user\", \"content\": \"${parameters.input}\"}]}";
234+
HttpConnector connector = createHttpConnectorWithRequestBody(requestBody);
235+
236+
Map<String, String> parameters = new HashMap<>();
237+
parameters.put("input", "Hello world");
238+
parameters.put("stream", "true");
239+
parameters.put("_llm_interface", "openai/v1/chat/completions");
240+
241+
String payload = connector.createPayload(PREDICT.name(), parameters);
242+
Assert
243+
.assertEquals(
244+
"{\"model\":\"gpt-3.5-turbo\",\"messages\":[{\"role\":\"user\",\"content\":\"Hello world\"}],\"stream\":true}",
245+
payload
246+
);
247+
}
248+
249+
@Test
250+
public void createPayload_WithoutStreamParameter() {
251+
String requestBody = "{\"model\": \"gpt-3.5-turbo\", \"messages\": [{\"role\": \"user\", \"content\": \"${parameters.input}\"}]}";
252+
HttpConnector connector = createHttpConnectorWithRequestBody(requestBody);
253+
254+
Map<String, String> parameters = new HashMap<>();
255+
parameters.put("input", "Hello world");
256+
parameters.put("_llm_interface", "openai/v1/chat/completions");
257+
258+
String payload = connector.createPayload(PREDICT.name(), parameters);
259+
Assert.assertEquals("{\"model\": \"gpt-3.5-turbo\", \"messages\": [{\"role\": \"user\", \"content\": \"Hello world\"}]}", payload);
260+
}
261+
262+
@Test
263+
public void createPayload_WithStreamParameter_UnsupportedInterface() {
264+
String requestBody = "{\"input\": \"${parameters.input}\"}";
265+
HttpConnector connector = createHttpConnectorWithRequestBody(requestBody);
266+
267+
Map<String, String> parameters = new HashMap<>();
268+
parameters.put("input", "Hello world");
269+
parameters.put("stream", "true");
270+
parameters.put("_llm_interface", "invalid/interface");
271+
272+
String payload = connector.createPayload(PREDICT.name(), parameters);
273+
274+
Assert.assertEquals("{\"input\": \"Hello world\"}", payload);
275+
}
276+
231277
@Test
232278
public void parseResponse_modelTensorJson() throws IOException {
233279
HttpConnector connector = createHttpConnector();

common/src/test/java/org/opensearch/ml/common/settings/MLCommonsSettingsTests.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,4 +102,9 @@ public void testAgenticMemoryDisabledMessage() {
102102
"The Agentic Memory APIs are not enabled. To enable, please update the setting plugins.ml_commons.agentic_memory_enabled";
103103
assertEquals(expectedMessage, MLCommonsSettings.ML_COMMONS_AGENTIC_MEMORY_DISABLED_MESSAGE);
104104
}
105+
106+
@Test
107+
public void testStreamDisabledByDefault() {
108+
assertFalse(MLCommonsSettings.ML_COMMONS_STREAM_ENABLED.getDefault(null));
109+
}
105110
}

common/src/test/java/org/opensearch/ml/common/settings/MLFeatureEnabledSettingTests.java

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,8 @@ public void setUp() {
4848
MLCommonsSettings.ML_COMMONS_AGENTIC_SEARCH_ENABLED,
4949
MLCommonsSettings.ML_COMMONS_MCP_CONNECTOR_ENABLED,
5050
MLCommonsSettings.ML_COMMONS_AGENTIC_MEMORY_ENABLED,
51-
MLCommonsSettings.ML_COMMONS_INDEX_INSIGHT_FEATURE_ENABLED
51+
MLCommonsSettings.ML_COMMONS_INDEX_INSIGHT_FEATURE_ENABLED,
52+
MLCommonsSettings.ML_COMMONS_STREAM_ENABLED
5253
)
5354
);
5455
when(mockClusterService.getClusterSettings()).thenReturn(mockClusterSettings);
@@ -73,6 +74,7 @@ public void testDefaults_allFeaturesEnabled() {
7374
.put("plugins.ml_commons.mcp_connector_enabled", true)
7475
.put("plugins.ml_commons.agentic_search_enabled", true)
7576
.put("plugins.ml_commons.agentic_memory_enabled", true)
77+
.put("plugins.ml_commons.stream_enabled", true)
7678
.build();
7779

7880
MLFeatureEnabledSetting setting = new MLFeatureEnabledSetting(mockClusterService, settings);
@@ -92,6 +94,7 @@ public void testDefaults_allFeaturesEnabled() {
9294
assertTrue(setting.isMcpConnectorEnabled());
9395
assertTrue(setting.isAgenticSearchEnabled());
9496
assertTrue(setting.isAgenticMemoryEnabled());
97+
assertTrue(setting.isStreamEnabled());
9598
}
9699

97100
@Test
@@ -113,6 +116,7 @@ public void testDefaults_someFeaturesDisabled() {
113116
.put("plugins.ml_commons.mcp_connector_enabled", false)
114117
.put("plugins.ml_commons.agentic_search_enabled", false)
115118
.put("plugins.ml_commons.agentic_memory_enabled", false)
119+
.put("plugins.ml_commons.stream_enabled", false)
116120
.build();
117121

118122
MLFeatureEnabledSetting setting = new MLFeatureEnabledSetting(mockClusterService, settings);
@@ -132,6 +136,7 @@ public void testDefaults_someFeaturesDisabled() {
132136
assertFalse(setting.isMcpConnectorEnabled());
133137
assertFalse(setting.isAgenticSearchEnabled());
134138
assertFalse(setting.isAgenticMemoryEnabled());
139+
assertFalse(setting.isStreamEnabled());
135140
}
136141

137142
@Test

ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutor.java

Lines changed: 2 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,11 @@
1313
import static software.amazon.awssdk.http.SdkHttpMethod.POST;
1414

1515
import java.security.AccessController;
16-
import java.security.PrivilegedActionException;
1716
import java.security.PrivilegedExceptionAction;
1817
import java.time.Duration;
1918
import java.util.Locale;
2019
import java.util.Map;
2120
import java.util.concurrent.CompletableFuture;
22-
import java.util.concurrent.TimeUnit;
2321
import java.util.concurrent.atomic.AtomicBoolean;
2422

2523
import org.apache.commons.text.StringEscapeUtils;
@@ -43,7 +41,6 @@
4341
import lombok.Getter;
4442
import lombok.Setter;
4543
import lombok.extern.log4j.Log4j2;
46-
import okhttp3.OkHttpClient;
4744
import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;
4845
import software.amazon.awssdk.auth.credentials.AwsSessionCredentials;
4946
import software.amazon.awssdk.auth.credentials.StaticCredentialsProvider;
@@ -83,16 +80,12 @@ public class AwsConnectorExecutor extends AbstractConnectorExecutor {
8380

8481
private SdkAsyncHttpClient httpClient;
8582

86-
private OkHttpClient okHttpClient;
87-
8883
private BedrockRuntimeAsyncClient bedrockRuntimeAsyncClient;
8984

9085
@Setter
9186
@Getter
9287
private StreamTransportService streamTransportService;
9388

94-
private AtomicBoolean isStreamClosed = new AtomicBoolean(false);
95-
9689
public AwsConnectorExecutor(Connector connector) {
9790
super.initialize(connector);
9891
this.connector = (AwsConnector) connector;
@@ -101,18 +94,6 @@ public AwsConnectorExecutor(Connector connector) {
10194
Integer maxConnection = super.getConnectorClientConfig().getMaxConnections();
10295
this.httpClient = MLHttpClientFactory.getAsyncHttpClient(connectionTimeout, readTimeout, maxConnection);
10396
this.bedrockRuntimeAsyncClient = buildBedrockRuntimeAsyncClient(httpClient);
104-
try {
105-
AccessController.doPrivileged((PrivilegedExceptionAction<Void>) () -> {
106-
this.okHttpClient = new OkHttpClient.Builder()
107-
.connectTimeout(10, TimeUnit.SECONDS)
108-
.readTimeout(1, TimeUnit.MINUTES)
109-
.retryOnConnectionFailure(true)
110-
.build();
111-
return null;
112-
});
113-
} catch (PrivilegedActionException e) {
114-
throw new RuntimeException("Failed to build OkHttpClient.", e);
115-
}
11697
}
11798

11899
@Override
@@ -179,6 +160,7 @@ public void invokeRemoteServiceStream(
179160
StreamPredictActionListener<MLTaskResponse, ?> actionListener
180161
) {
181162
try {
163+
AtomicBoolean isStreamClosed = new AtomicBoolean(false);
182164
String llmInterface = parameters.get(LLM_INTERFACE);
183165
llmInterface = llmInterface.trim().toLowerCase(Locale.ROOT);
184166
llmInterface = StringEscapeUtils.unescapeJava(llmInterface);
@@ -196,6 +178,7 @@ public void invokeRemoteServiceStream(
196178
}).onError(error -> {
197179
// Handle errors
198180
log.error("Converse stream error: {}", error.getMessage());
181+
actionListener.onFailure(new MLException("Error from remote service: " + error.getMessage(), error));
199182
}).onComplete(() -> {
200183
// Handle completion
201184
log.debug("Converse stream complete");

ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutor.java

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import static software.amazon.awssdk.http.SdkHttpMethod.GET;
1313
import static software.amazon.awssdk.http.SdkHttpMethod.POST;
1414

15+
import java.io.IOException;
1516
import java.net.URL;
1617
import java.security.AccessController;
1718
import java.security.PrivilegedActionException;
@@ -274,14 +275,25 @@ public void onClosed(EventSource eventSource) {
274275
*/
275276
@Override
276277
public void onFailure(EventSource eventSource, Throwable t, Response response) {
277-
log.error("SSE failure.");
278278
if (t != null) {
279+
// Network/connection error
279280
log.error("Error: " + t.getMessage(), t);
280281
if (t instanceof StreamResetException && t.getMessage().contains("NO_ERROR")) {
281282
// TODO: reconnect
282283
} else {
283-
streamActionListener.onFailure(new MLException("SSE failure.", t));
284+
streamActionListener.onFailure(new MLException("SSE failure with network error", t));
284285
}
286+
} else if (response != null) {
287+
// HTTP error (e.g., 400 Bad Request)
288+
try {
289+
String errorBody = response.body() != null ? response.body().string() : "";
290+
streamActionListener.onFailure(new MLException("Error from remote service: " + errorBody));
291+
} catch (IOException e) {
292+
streamActionListener.onFailure(new MLException("SSE failure - unable to read error details"));
293+
}
294+
} else {
295+
// Unknown failure
296+
streamActionListener.onFailure(new MLException("SSE failure"));
285297
}
286298
}
287299

ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/StreamPredictActionListener.java

Lines changed: 40 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,26 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
16
package org.opensearch.ml.engine.algorithms.remote;
27

3-
import java.io.IOException;
8+
import java.util.LinkedHashMap;
9+
import java.util.List;
10+
import java.util.Map;
411

512
import org.opensearch.core.action.ActionListener;
613
import org.opensearch.core.transport.TransportResponse;
14+
import org.opensearch.ml.common.output.model.ModelTensor;
15+
import org.opensearch.ml.common.output.model.ModelTensorOutput;
16+
import org.opensearch.ml.common.output.model.ModelTensors;
17+
import org.opensearch.ml.common.transport.MLTaskResponse;
718
import org.opensearch.transport.TransportChannel;
819
import org.opensearch.transport.TransportRequest;
920

21+
import lombok.extern.log4j.Log4j2;
22+
23+
@Log4j2
1024
public class StreamPredictActionListener<Response extends TransportResponse, Request extends TransportRequest>
1125
implements
1226
ActionListener<Response> {
@@ -46,10 +60,32 @@ public final void onResponse(Response response) {
4660
@Override
4761
public void onFailure(Exception e) {
4862
try {
49-
channel.sendResponse(e);
50-
} catch (IOException exc) {
63+
MLTaskResponse errorResponse = createErrorResponse(e);
64+
channel.sendResponseBatch(errorResponse);
5165
channel.completeStream();
52-
throw new RuntimeException(exc);
66+
} catch (Exception exc) {
67+
try {
68+
channel.completeStream();
69+
} catch (Exception streamException) {
70+
log.error("Failed to complete stream", streamException);
71+
}
72+
}
73+
}
74+
75+
private MLTaskResponse createErrorResponse(Exception error) {
76+
String errorMessage = error.getMessage();
77+
if (errorMessage == null || errorMessage.trim().isEmpty()) {
78+
errorMessage = "Request failed";
5379
}
80+
81+
Map<String, Object> errorData = new LinkedHashMap<>();
82+
errorData.put("error", errorMessage);
83+
errorData.put("is_last", true);
84+
85+
ModelTensor errorTensor = ModelTensor.builder().name("error").dataAsMap(errorData).build();
86+
ModelTensors errorTensors = ModelTensors.builder().mlModelTensors(List.of(errorTensor)).build();
87+
ModelTensorOutput errorOutput = ModelTensorOutput.builder().mlModelOutputs(List.of(errorTensors)).build();
88+
89+
return MLTaskResponse.builder().output(errorOutput).build();
5490
}
5591
}

ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutorTest.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -264,7 +264,7 @@ public void executePredict_TextDocsInferenceInput_withStepSize() {
264264
);
265265

266266
Mockito.verify(actionListener, times(0)).onFailure(any());
267-
Mockito.verify(executor, times(3)).preparePayloadAndInvoke(anyString(), any(), any(), any());
267+
Mockito.verify(executor, times(3)).preparePayloadAndInvoke(anyString(), any(), any(), any(), null);
268268
}
269269

270270
@Test

ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/RemoteConnectorExecutorTest.java

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,7 @@ public void executePreparePayloadAndInvoke_SkipValidateMissingParameterDisabled(
134134
Exception exception = Assert
135135
.assertThrows(
136136
IllegalArgumentException.class,
137-
() -> executor.preparePayloadAndInvoke(actionType, mlInput, null, actionListener)
137+
() -> executor.preparePayloadAndInvoke(actionType, mlInput, null, actionListener, null)
138138
);
139139
assert exception.getMessage().contains("Some parameter placeholder not filled in payload: role");
140140
}
@@ -154,7 +154,7 @@ public void executePreparePayloadAndInvoke_SkipValidateMissingParameterEnabled()
154154
String actionType = inputDataSet.getActionType().toString();
155155
MLInput mlInput = MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(inputDataSet).build();
156156

157-
executor.preparePayloadAndInvoke(actionType, mlInput, null, actionListener);
157+
executor.preparePayloadAndInvoke(actionType, mlInput, null, actionListener, null);
158158
Mockito
159159
.verify(executor, times(1))
160160
.invokeRemoteService(any(), any(), any(), argThat(argument -> argument.contains("You are a ${parameters.role}")), any(), any());
@@ -177,7 +177,7 @@ public void executePreparePayloadAndInvoke_SkipValidateMissingParameterDefault()
177177
Exception exception = Assert
178178
.assertThrows(
179179
IllegalArgumentException.class,
180-
() -> executor.preparePayloadAndInvoke(actionType, mlInput, null, actionListener)
180+
() -> executor.preparePayloadAndInvoke(actionType, mlInput, null, actionListener, null)
181181
);
182182
assert exception.getMessage().contains("Some parameter placeholder not filled in payload: role");
183183
}
@@ -209,7 +209,7 @@ public void executePreparePayloadAndInvoke_PassingParameter() {
209209
Exception exception = Assert
210210
.assertThrows(
211211
IllegalArgumentException.class,
212-
() -> executor.preparePayloadAndInvoke(actionType, mlInput, null, actionListener)
212+
() -> executor.preparePayloadAndInvoke(actionType, mlInput, null, actionListener, null)
213213
);
214214
assert exception.getMessage().contains("Some parameter placeholder not filled in payload: role");
215215
}
@@ -234,7 +234,7 @@ public void executePreparePayloadAndInvoke_GetParamsIOException() throws Excepti
234234
.inputDataset(inputDataSet)
235235
.build();
236236

237-
executor.preparePayloadAndInvoke(actionType, mlInput, null, actionListener);
237+
executor.preparePayloadAndInvoke(actionType, mlInput, null, actionListener, null);
238238
verify(actionListener).onFailure(argThat(e -> e instanceof IOException && e.getMessage().contains("UT test IOException")));
239239
}
240240

0 commit comments

Comments
 (0)