diff --git a/common/src/main/java/org/opensearch/ml/common/connector/HttpConnector.java b/common/src/main/java/org/opensearch/ml/common/connector/HttpConnector.java index dbe02bf22a..831e94d15a 100644 --- a/common/src/main/java/org/opensearch/ml/common/connector/HttpConnector.java +++ b/common/src/main/java/org/opensearch/ml/common/connector/HttpConnector.java @@ -26,6 +26,7 @@ import java.util.regex.Matcher; import java.util.regex.Pattern; +import org.apache.commons.lang3.StringUtils; import org.apache.commons.text.StringEscapeUtils; import org.apache.commons.text.StringSubstitutor; import org.opensearch.Version; @@ -56,6 +57,8 @@ public class HttpConnector extends AbstractConnector { public static final String PARAMETERS_FIELD = "parameters"; public static final String SERVICE_NAME_FIELD = "service_name"; public static final String REGION_FIELD = "region"; + // TODO: move the AgentUtils class from algorithm module to common module + public static final String LLM_INTERFACE_OPENAI_V1_CHAT_COMPLETIONS = "openai/v1/chat/completions"; // TODO: add RequestConfig like request time out, @@ -377,14 +380,14 @@ private boolean neededStreamParameterInPayload(Map parameters) { } String llmInterface = parameters.get("_llm_interface"); - if (llmInterface.isBlank()) { + if (StringUtils.isBlank(llmInterface)) { return false; } llmInterface = llmInterface.trim().toLowerCase(Locale.ROOT); llmInterface = StringEscapeUtils.unescapeJava(llmInterface); switch (llmInterface) { - case "openai/v1/chat/completions": + case LLM_INTERFACE_OPENAI_V1_CHAT_COMPLETIONS: return true; default: return false; diff --git a/common/src/main/java/org/opensearch/ml/common/transport/execute/MLExecuteStreamTaskAction.java b/common/src/main/java/org/opensearch/ml/common/transport/execute/MLExecuteStreamTaskAction.java new file mode 100644 index 0000000000..4c6be87ecc --- /dev/null +++ b/common/src/main/java/org/opensearch/ml/common/transport/execute/MLExecuteStreamTaskAction.java @@ -0,0 +1,17 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.common.transport.execute; + +import org.opensearch.action.ActionType; + +public class MLExecuteStreamTaskAction extends ActionType { + public static final MLExecuteStreamTaskAction INSTANCE = new MLExecuteStreamTaskAction(); + public static final String NAME = "cluster:admin/opensearch/ml/execute/stream"; + + private MLExecuteStreamTaskAction() { + super(NAME, MLExecuteTaskResponse::new); + } +} diff --git a/common/src/main/java/org/opensearch/ml/common/transport/execute/MLExecuteTaskRequest.java b/common/src/main/java/org/opensearch/ml/common/transport/execute/MLExecuteTaskRequest.java index d998ea71de..3ce3e807ca 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/execute/MLExecuteTaskRequest.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/execute/MLExecuteTaskRequest.java @@ -22,19 +22,27 @@ import org.opensearch.ml.common.MLCommonsClassLoader; import org.opensearch.ml.common.input.Input; import org.opensearch.ml.common.transport.MLTaskRequest; +import org.opensearch.transport.TransportChannel; import lombok.AccessLevel; import lombok.Builder; import lombok.Getter; import lombok.NonNull; +import lombok.Setter; import lombok.ToString; import lombok.experimental.FieldDefaults; +import lombok.experimental.NonFinal; @Getter @FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE) @ToString public class MLExecuteTaskRequest extends MLTaskRequest { + @Getter + @Setter + @NonFinal + private transient TransportChannel streamingChannel; + FunctionName functionName; Input input; diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/Executable.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/Executable.java index addfdb4bc6..171aac8dc6 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/Executable.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/Executable.java @@ -9,12 +9,24 @@ import org.opensearch.ml.common.exception.ExecuteException; import org.opensearch.ml.common.input.Input; import org.opensearch.ml.common.output.Output; +import org.opensearch.transport.TransportChannel; public interface Executable { /** - * Execute algorithm with given input data. + * Execute algorithm with given input data (non-streaming). * @param input input data + * @param listener action listener */ - void execute(Input input, ActionListener listener) throws ExecuteException; + default void execute(Input input, ActionListener listener) throws ExecuteException { + execute(input, listener, null); + } + + /** + * Execute algorithm with given input data (streaming). + * @param input input data + * @param listener action listener + * @param channel transport channel + */ + default void execute(Input input, ActionListener listener, TransportChannel channel) {} } diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/MLEngine.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/MLEngine.java index d468833b53..05f97475de 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/MLEngine.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/MLEngine.java @@ -26,6 +26,7 @@ import org.opensearch.ml.common.output.MLOutput; import org.opensearch.ml.common.output.Output; import org.opensearch.ml.engine.encryptor.Encryptor; +import org.opensearch.transport.TransportChannel; import lombok.Getter; import lombok.extern.log4j.Log4j2; @@ -186,7 +187,7 @@ public MLOutput trainAndPredict(Input input) { return trainAndPredictable.trainAndPredict(mlInput); } - public void execute(Input input, ActionListener listener) throws Exception { + public void execute(Input input, ActionListener listener, TransportChannel channel) throws Exception { validateInput(input); if (input.getFunctionName() == FunctionName.METRICS_CORRELATION) { MLExecutable executable = MLEngineClassLoader.initInstance(input.getFunctionName(), input, Input.class); @@ -199,6 +200,10 @@ public void execute(Input input, ActionListener listener) throws Excepti if (executable == null) { throw new IllegalArgumentException("Unsupported executable function: " + input.getFunctionName()); } + if (channel != null) { + executable.execute(input, listener, channel); + return; + } executable.execute(input, listener); } } diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLAgentExecutor.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLAgentExecutor.java index dbce9b68c5..4c8ca942fe 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLAgentExecutor.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLAgentExecutor.java @@ -80,6 +80,7 @@ import org.opensearch.remote.metadata.client.SdkClient; import org.opensearch.remote.metadata.common.SdkClientUtils; import org.opensearch.search.fetch.subphase.FetchSourceContext; +import org.opensearch.transport.TransportChannel; import org.opensearch.transport.client.Client; import com.google.common.annotations.VisibleForTesting; @@ -143,7 +144,7 @@ public void onMultiTenancyEnabledChanged(boolean isEnabled) { } @Override - public void execute(Input input, ActionListener listener) { + public void execute(Input input, ActionListener listener, TransportChannel channel) { if (!(input instanceof AgentMLInput)) { throw new IllegalArgumentException("wrong input"); } @@ -271,7 +272,8 @@ public void execute(Input input, ActionListener listener) { isAsync, outputs, modelTensors, - mlAgent + mlAgent, + channel ); }, e -> { log.error("Failed to get existing interaction for regeneration", e); @@ -287,7 +289,8 @@ public void execute(Input input, ActionListener listener) { isAsync, outputs, modelTensors, - mlAgent + mlAgent, + channel ); } }, ex -> { @@ -318,7 +321,8 @@ public void execute(Input input, ActionListener listener) { outputs, modelTensors, listener, - createdMemory + createdMemory, + channel ), ex -> { log.error("Failed to find memory with memory_id: {}", memoryId, ex); @@ -329,7 +333,6 @@ public void execute(Input input, ActionListener listener) { return; } } - executeAgent( inputDataSet, mlTask, @@ -339,7 +342,8 @@ public void execute(Input input, ActionListener listener) { outputs, modelTensors, listener, - null + null, + channel ); } } catch (Exception e) { @@ -382,7 +386,8 @@ private void saveRootInteractionAndExecute( boolean isAsync, List outputs, List modelTensors, - MLAgent mlAgent + MLAgent mlAgent, + TransportChannel channel ) { String appType = mlAgent.getAppType(); String question = inputDataSet.getParameters().get(QUESTION); @@ -416,7 +421,8 @@ private void saveRootInteractionAndExecute( outputs, modelTensors, listener, - memory + memory, + channel ), e -> { log.error("Failed to regenerate for interaction {}", regenerateInteractionId, e); @@ -425,7 +431,18 @@ private void saveRootInteractionAndExecute( ) ); } else { - executeAgent(inputDataSet, mlTask, isAsync, memory.getConversationId(), mlAgent, outputs, modelTensors, listener, memory); + executeAgent( + inputDataSet, + mlTask, + isAsync, + memory.getConversationId(), + mlAgent, + outputs, + modelTensors, + listener, + memory, + channel + ); } }, ex -> { log.error("Failed to create parent interaction", ex); @@ -442,7 +459,8 @@ private void executeAgent( List outputs, List modelTensors, ActionListener listener, - ConversationIndexMemory memory + ConversationIndexMemory memory, + TransportChannel channel ) { String mcpConnectorConfigJSON = (mlAgent.getParameters() != null) ? mlAgent.getParameters().get(MCP_CONNECTORS_FIELD) : null; if (mcpConnectorConfigJSON != null && !mlFeatureEnabledSetting.isMcpConnectorEnabled()) { @@ -494,7 +512,7 @@ private void executeAgent( memory ); inputDataSet.getParameters().put(TASK_ID_FIELD, taskId); - mlAgentRunner.run(mlAgent, inputDataSet.getParameters(), agentActionListener); + mlAgentRunner.run(mlAgent, inputDataSet.getParameters(), agentActionListener, channel); }, e -> { log.error("Failed to create task for agent async execution", e); listener.onFailure(e); @@ -508,7 +526,7 @@ private void executeAgent( parentInteractionId, memory ); - mlAgentRunner.run(mlAgent, inputDataSet.getParameters(), agentActionListener); + mlAgentRunner.run(mlAgent, inputDataSet.getParameters(), agentActionListener, channel); } } diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLAgentRunner.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLAgentRunner.java index fd3d48208d..dfc6099e5f 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLAgentRunner.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLAgentRunner.java @@ -9,6 +9,7 @@ import org.opensearch.core.action.ActionListener; import org.opensearch.ml.common.agent.MLAgent; +import org.opensearch.transport.TransportChannel; /** * Agent executor interface definition. Agent executor will be used by {@link MLAgentExecutor} to invoke agents. @@ -16,10 +17,21 @@ public interface MLAgentRunner { /** - * Function interface to execute agent. + * Function interface to execute agent (non-streaming) * @param mlAgent * @param params * @param listener */ - void run(MLAgent mlAgent, Map params, ActionListener listener); + default void run(MLAgent mlAgent, Map params, ActionListener listener) { + run(mlAgent, params, listener, null); + } + + /** + * Function interface to execute agent (streaming) + * @param mlAgent + * @param params + * @param listener + * @param channel + */ + void run(MLAgent mlAgent, Map params, ActionListener listener, TransportChannel channel); } diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java index 74a70f5ff0..7e1a4050bd 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunner.java @@ -57,13 +57,10 @@ import org.opensearch.core.action.ActionListener; import org.opensearch.core.common.Strings; import org.opensearch.core.xcontent.NamedXContentRegistry; -import org.opensearch.ml.common.FunctionName; import org.opensearch.ml.common.agent.LLMSpec; import org.opensearch.ml.common.agent.MLAgent; import org.opensearch.ml.common.agent.MLToolSpec; import org.opensearch.ml.common.conversation.Interaction; -import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; -import org.opensearch.ml.common.input.remote.RemoteInferenceMLInput; import org.opensearch.ml.common.output.model.ModelTensor; import org.opensearch.ml.common.output.model.ModelTensorOutput; import org.opensearch.ml.common.output.model.ModelTensors; @@ -71,8 +68,6 @@ import org.opensearch.ml.common.spi.memory.Message; import org.opensearch.ml.common.spi.tools.Tool; import org.opensearch.ml.common.transport.MLTaskResponse; -import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction; -import org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest; import org.opensearch.ml.common.utils.StringUtils; import org.opensearch.ml.engine.encryptor.Encryptor; import org.opensearch.ml.engine.function_calling.FunctionCalling; @@ -84,6 +79,7 @@ import org.opensearch.ml.repackage.com.google.common.collect.ImmutableMap; import org.opensearch.ml.repackage.com.google.common.collect.Lists; import org.opensearch.remote.metadata.client.SdkClient; +import org.opensearch.transport.TransportChannel; import org.opensearch.transport.client.Client; import com.google.common.annotations.VisibleForTesting; @@ -138,6 +134,7 @@ public class MLChatAgentRunner implements MLAgentRunner { private Map memoryFactoryMap; private SdkClient sdkClient; private Encryptor encryptor; + private StreamingWrapper streamingWrapper; public MLChatAgentRunner( Client client, @@ -160,7 +157,8 @@ public MLChatAgentRunner( } @Override - public void run(MLAgent mlAgent, Map inputParams, ActionListener listener) { + public void run(MLAgent mlAgent, Map inputParams, ActionListener listener, TransportChannel channel) { + this.streamingWrapper = new StreamingWrapper(channel, client); Map params = new HashMap<>(); if (mlAgent.getParameters() != null) { params.putAll(mlAgent.getParameters()); @@ -348,6 +346,7 @@ private void runReAct( functionCalling ); + streamingWrapper.fixInteractionRole(interactions); String thought = String.valueOf(modelOutput.get(THOUGHT)); String toolCallId = String.valueOf(modelOutput.get("tool_call_id")); String action = String.valueOf(modelOutput.get(ACTION)); @@ -487,6 +486,7 @@ private void runReAct( } sessionMsgAnswerBuilder.append(outputToOutputString(filteredOutput)); + streamingWrapper.sendToolResponse(outputToOutputString(output), sessionId, parentInteractionId); traceTensors .add( ModelTensors @@ -518,18 +518,8 @@ private void runReAct( ); return; } - - ActionRequest request = new MLPredictionTaskRequest( - llm.getModelId(), - RemoteInferenceMLInput - .builder() - .algorithm(FunctionName.REMOTE) - .inputDataset(RemoteInferenceInputDataSet.builder().parameters(tmpParameters).build()) - .build(), - null, - tenantId - ); - client.execute(MLPredictionTaskAction.INSTANCE, request, (ActionListener) nextStepListener); + ActionRequest request = streamingWrapper.createPredictionRequest(llm, tmpParameters, tenantId); + streamingWrapper.executeRequest(request, (ActionListener) nextStepListener); } }, e -> { log.error("Failed to run chat agent", e); @@ -540,17 +530,8 @@ private void runReAct( } } - ActionRequest request = new MLPredictionTaskRequest( - llm.getModelId(), - RemoteInferenceMLInput - .builder() - .algorithm(FunctionName.REMOTE) - .inputDataset(RemoteInferenceInputDataSet.builder().parameters(tmpParameters).build()) - .build(), - null, - tenantId - ); - client.execute(MLPredictionTaskAction.INSTANCE, request, firstListener); + ActionRequest request = streamingWrapper.createPredictionRequest(llm, tmpParameters, tenantId); + streamingWrapper.executeRequest(request, firstListener); } private static List createFinalAnswerTensors(List sessionId, List lastThought) { @@ -737,6 +718,9 @@ private void sendFinalAnswer( Map additionalInfo, String finalAnswer ) { + // Send completion chunk for streaming + streamingWrapper.sendCompletionChunk(sessionId, parentInteractionId); + if (conversationIndexMemory != null) { String copyOfFinalAnswer = finalAnswer; ActionListener saveTraceListener = ActionListener.wrap(r -> { @@ -770,7 +754,8 @@ private void sendFinalAnswer( saveTraceListener ); } else { - returnFinalResponse(sessionId, listener, parentInteractionId, verbose, cotModelTensors, additionalInfo, finalAnswer); + streamingWrapper + .sendFinalResponse(sessionId, listener, parentInteractionId, verbose, cotModelTensors, additionalInfo, finalAnswer); } } @@ -857,7 +842,7 @@ static Map constructLLMParams(LLMSpec llm, Map p return tmpParameters; } - private static void returnFinalResponse( + public static void returnFinalResponse( String sessionId, ActionListener listener, String parentInteractionId, diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLConversationalFlowAgentRunner.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLConversationalFlowAgentRunner.java index e6685ca619..54d847b929 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLConversationalFlowAgentRunner.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLConversationalFlowAgentRunner.java @@ -51,6 +51,7 @@ import org.opensearch.ml.engine.memory.ConversationIndexMessage; import org.opensearch.ml.repackage.com.google.common.annotations.VisibleForTesting; import org.opensearch.remote.metadata.client.SdkClient; +import org.opensearch.transport.TransportChannel; import org.opensearch.transport.client.Client; import lombok.Data; @@ -94,7 +95,7 @@ public MLConversationalFlowAgentRunner( } @Override - public void run(MLAgent mlAgent, Map params, ActionListener listener) { + public void run(MLAgent mlAgent, Map params, ActionListener listener, TransportChannel channel) { String appType = mlAgent.getAppType(); String memoryId = params.get(MLAgentExecutor.MEMORY_ID); String parentInteractionId = params.get(MLAgentExecutor.PARENT_INTERACTION_ID); diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLFlowAgentRunner.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLFlowAgentRunner.java index 4e29afc220..30725a8c47 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLFlowAgentRunner.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLFlowAgentRunner.java @@ -39,6 +39,7 @@ import org.opensearch.ml.repackage.com.google.common.annotations.VisibleForTesting; import org.opensearch.ml.repackage.com.google.common.collect.ImmutableMap; import org.opensearch.remote.metadata.client.SdkClient; +import org.opensearch.transport.TransportChannel; import org.opensearch.transport.client.Client; import lombok.Data; @@ -81,7 +82,7 @@ public MLFlowAgentRunner( @SuppressWarnings("removal") @Override - public void run(MLAgent mlAgent, Map params, ActionListener listener) { + public void run(MLAgent mlAgent, Map params, ActionListener listener, TransportChannel channel) { List toolSpecs = getMlToolSpecs(mlAgent, params); StepListener firstStepListener = null; Tool firstTool = null; diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLPlanExecuteAndReflectAgentRunner.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLPlanExecuteAndReflectAgentRunner.java index 2cbebde104..b8b89d8aa2 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLPlanExecuteAndReflectAgentRunner.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLPlanExecuteAndReflectAgentRunner.java @@ -71,6 +71,7 @@ import org.opensearch.ml.engine.encryptor.Encryptor; import org.opensearch.ml.engine.memory.ConversationIndexMemory; import org.opensearch.remote.metadata.client.SdkClient; +import org.opensearch.transport.TransportChannel; import org.opensearch.transport.client.Client; import com.google.common.annotations.VisibleForTesting; @@ -270,7 +271,7 @@ void populatePrompt(Map allParams) { } @Override - public void run(MLAgent mlAgent, Map apiParams, ActionListener listener) { + public void run(MLAgent mlAgent, Map apiParams, ActionListener listener, TransportChannel channel) { Map allParams = new HashMap<>(); allParams.putAll(apiParams); allParams.putAll(mlAgent.getParameters()); diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/StreamingWrapper.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/StreamingWrapper.java new file mode 100644 index 0000000000..beb5e60f53 --- /dev/null +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/StreamingWrapper.java @@ -0,0 +1,137 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.engine.algorithms.agent; + +import static org.opensearch.ml.common.utils.StringUtils.gson; +import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.returnFinalResponse; + +import java.util.Arrays; +import java.util.List; +import java.util.Map; + +import org.opensearch.action.ActionRequest; +import org.opensearch.core.action.ActionListener; +import org.opensearch.ml.common.FunctionName; +import org.opensearch.ml.common.agent.LLMSpec; +import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; +import org.opensearch.ml.common.input.remote.RemoteInferenceMLInput; +import org.opensearch.ml.common.output.model.ModelTensor; +import org.opensearch.ml.common.output.model.ModelTensorOutput; +import org.opensearch.ml.common.output.model.ModelTensors; +import org.opensearch.ml.common.transport.MLTaskResponse; +import org.opensearch.ml.common.transport.prediction.MLPredictionStreamTaskAction; +import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction; +import org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest; +import org.opensearch.ml.common.utils.StringUtils; +import org.opensearch.transport.TransportChannel; +import org.opensearch.transport.client.Client; + +import lombok.extern.log4j.Log4j2; + +@Log4j2 +public class StreamingWrapper { + private final TransportChannel channel; + private final boolean isStreaming; + private Client client; + + public StreamingWrapper(TransportChannel channel, org.opensearch.transport.client.Client client) { + this.channel = channel; + this.client = client; + this.isStreaming = (channel != null); + } + + public void fixInteractionRole(List interactions) { + if (isStreaming && !interactions.isEmpty()) { + try { + String lastInteraction = interactions.get(interactions.size() - 1); + Map messageMap = gson.fromJson(lastInteraction, Map.class); + + if (!messageMap.containsKey("role") && messageMap.containsKey("tool_calls")) { + messageMap.put("role", "assistant"); + interactions.set(interactions.size() - 1, StringUtils.toJson(messageMap)); + } + } catch (Exception e) { + log.error("Failed to fix assistant message role after parseLLMOutput", e); + } + } + } + + public ActionRequest createPredictionRequest(LLMSpec llm, Map parameters, String tenantId) { + return new MLPredictionTaskRequest( + llm.getModelId(), + RemoteInferenceMLInput + .builder() + .algorithm(FunctionName.REMOTE) + .inputDataset(RemoteInferenceInputDataSet.builder().parameters(parameters).build()) + .build(), + // TODO: handle agent streaming in multi-node + !isStreaming, // set dispatchTask to false for streaming + null, + tenantId + ); + } + + public void executeRequest(ActionRequest request, ActionListener listener) { + if (isStreaming) { + ((MLPredictionTaskRequest) request).setStreamingChannel(channel); + client.execute(MLPredictionStreamTaskAction.INSTANCE, request, listener); + return; + } + client.execute(MLPredictionTaskAction.INSTANCE, request, listener); + } + + public void sendCompletionChunk(String sessionId, String parentInteractionId) { + if (!isStreaming) { + return; + } + MLTaskResponse completionChunk = createStreamChunk("", sessionId, parentInteractionId, true); + try { + channel.sendResponseBatch(completionChunk); + } catch (Exception e) { + log.warn("Failed to send completion chunk: {}", e.getMessage()); + } + } + + public void sendFinalResponse( + String sessionId, + ActionListener listener, + String parentInteractionId, + boolean verbose, + List cotModelTensors, + Map additionalInfo, + String finalAnswer + ) { + if (isStreaming) { + listener.onResponse("Streaming completed"); + } else { + returnFinalResponse(sessionId, listener, parentInteractionId, verbose, cotModelTensors, additionalInfo, finalAnswer); + } + } + + public void sendToolResponse(String toolOutput, String sessionId, String parentInteractionId) { + if (isStreaming) { + try { + MLTaskResponse toolChunk = createStreamChunk(toolOutput, sessionId, parentInteractionId, false); + channel.sendResponseBatch(toolChunk); + } catch (Exception e) { + log.error("Failed to send tool response chunk", e); + } + } + } + + private MLTaskResponse createStreamChunk(String toolOutput, String sessionId, String parentInteractionId, boolean isLast) { + List tensors = Arrays + .asList( + ModelTensor.builder().name("response").dataAsMap(Map.of("content", toolOutput, "is_last", isLast)).build(), + ModelTensor.builder().name("memory_id").result(sessionId).build(), + ModelTensor.builder().name("parent_interaction_id").result(parentInteractionId).build() + ); + + ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(tensors).build(); + ModelTensorOutput output = ModelTensorOutput.builder().mlModelOutputs(List.of(modelTensors)).build(); + return new MLTaskResponse(output); + } +} diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/AbstractConnectorExecutor.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/AbstractConnectorExecutor.java index a4f5ec4fe4..46c653776d 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/AbstractConnectorExecutor.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/AbstractConnectorExecutor.java @@ -5,17 +5,8 @@ package org.opensearch.ml.engine.algorithms.remote; -import java.util.ArrayList; -import java.util.List; -import java.util.Map; -import java.util.concurrent.atomic.AtomicBoolean; - import org.opensearch.ml.common.connector.Connector; import org.opensearch.ml.common.connector.ConnectorClientConfig; -import org.opensearch.ml.common.output.model.ModelTensor; -import org.opensearch.ml.common.output.model.ModelTensorOutput; -import org.opensearch.ml.common.output.model.ModelTensors; -import org.opensearch.ml.common.transport.MLTaskResponse; import lombok.Getter; import lombok.Setter; @@ -32,23 +23,4 @@ public void initialize(Connector connector) { connectorClientConfig = new ConnectorClientConfig(); } } - - public void sendContentResponse(String content, boolean isLast, StreamPredictActionListener actionListener) { - List modelTensors = new ArrayList<>(); - Map dataMap = Map.of("content", content, "is_last", isLast); - - modelTensors.add(ModelTensor.builder().name("response").dataAsMap(dataMap).build()); - ModelTensorOutput output = ModelTensorOutput - .builder() - .mlModelOutputs(List.of(ModelTensors.builder().mlModelTensors(modelTensors).build())) - .build(); - MLTaskResponse response = MLTaskResponse.builder().output(output).build(); - actionListener.onStreamResponse(response, isLast); - } - - public void sendCompletionResponse(AtomicBoolean isStreamClosed, StreamPredictActionListener actionListener) { - if (isStreamClosed.compareAndSet(false, true)) { - sendContentResponse("", true, actionListener); - } - } } diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutor.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutor.java index 484058d550..70673870ce 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutor.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutor.java @@ -18,7 +18,6 @@ import java.util.Locale; import java.util.Map; import java.util.concurrent.CompletableFuture; -import java.util.concurrent.atomic.AtomicBoolean; import org.apache.commons.text.StringEscapeUtils; import org.apache.logging.log4j.Logger; @@ -33,6 +32,9 @@ import org.opensearch.ml.common.model.MLGuard; import org.opensearch.ml.common.output.model.ModelTensors; import org.opensearch.ml.common.transport.MLTaskResponse; +import org.opensearch.ml.engine.algorithms.remote.streaming.StreamPredictActionListener; +import org.opensearch.ml.engine.algorithms.remote.streaming.StreamingHandler; +import org.opensearch.ml.engine.algorithms.remote.streaming.StreamingHandlerFactory; import org.opensearch.ml.engine.annotation.ConnectorExecutor; import org.opensearch.script.ScriptService; import org.opensearch.transport.StreamTransportService; @@ -41,20 +43,10 @@ import lombok.Getter; import lombok.Setter; import lombok.extern.log4j.Log4j2; -import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; -import software.amazon.awssdk.auth.credentials.AwsSessionCredentials; -import software.amazon.awssdk.auth.credentials.StaticCredentialsProvider; import software.amazon.awssdk.core.internal.http.async.SimpleHttpContentPublisher; import software.amazon.awssdk.http.SdkHttpFullRequest; import software.amazon.awssdk.http.async.AsyncExecuteRequest; import software.amazon.awssdk.http.async.SdkAsyncHttpClient; -import software.amazon.awssdk.regions.Region; -import software.amazon.awssdk.services.bedrockruntime.BedrockRuntimeAsyncClient; -import software.amazon.awssdk.services.bedrockruntime.model.ContentBlock; -import software.amazon.awssdk.services.bedrockruntime.model.ContentBlockDeltaEvent; -import software.amazon.awssdk.services.bedrockruntime.model.ConverseStreamRequest; -import software.amazon.awssdk.services.bedrockruntime.model.ConverseStreamResponseHandler; -import software.amazon.awssdk.services.bedrockruntime.model.Message; @Log4j2 @ConnectorExecutor(AWS_SIGV4) @@ -80,8 +72,6 @@ public class AwsConnectorExecutor extends AbstractConnectorExecutor { private SdkAsyncHttpClient httpClient; - private BedrockRuntimeAsyncClient bedrockRuntimeAsyncClient; - @Setter @Getter private StreamTransportService streamTransportService; @@ -93,7 +83,6 @@ public AwsConnectorExecutor(Connector connector) { Duration readTimeout = Duration.ofSeconds(super.getConnectorClientConfig().getReadTimeout()); Integer maxConnection = super.getConnectorClientConfig().getMaxConnections(); this.httpClient = MLHttpClientFactory.getAsyncHttpClient(connectionTimeout, readTimeout, maxConnection); - this.bedrockRuntimeAsyncClient = null; } @Override @@ -160,73 +149,19 @@ public void invokeRemoteServiceStream( StreamPredictActionListener actionListener ) { try { - AtomicBoolean isStreamClosed = new AtomicBoolean(false); String llmInterface = parameters.get(LLM_INTERFACE); llmInterface = llmInterface.trim().toLowerCase(Locale.ROOT); llmInterface = StringEscapeUtils.unescapeJava(llmInterface); validateLLMInterface(llmInterface); - ConverseStreamRequest request = ConverseStreamRequest - .builder() - .modelId(parameters.get("model")) - .messages(Message.builder().role("user").content(ContentBlock.builder().text(parameters.get("inputs")).build()).build()) - .build(); - - ConverseStreamResponseHandler handler = ConverseStreamResponseHandler.builder().onResponse(response -> { - // Handle initial response - log.debug("Initial converse stream response: {}", response); - }).onError(error -> { - // Handle errors - log.error("Converse stream error: {}", error.getMessage()); - actionListener.onFailure(new MLException("Error from remote service: " + error.getMessage(), error)); - }).onComplete(() -> { - // Handle completion - log.debug("Converse stream complete"); - sendCompletionResponse(isStreamClosed, actionListener); - }).subscriber(event -> { - log.debug("Converse stream event: {}", event); - switch (event.sdkEventType()) { - case CONTENT_BLOCK_DELTA: - ContentBlockDeltaEvent contentEvent = (ContentBlockDeltaEvent) event; - String chunk = contentEvent.delta().text(); - sendContentResponse(chunk, false, actionListener); - break; - default: - // Ignore the other event types for now. - break; - } - }).build(); - if (bedrockRuntimeAsyncClient == null) { - bedrockRuntimeAsyncClient = buildBedrockRuntimeAsyncClient(httpClient); - } - bedrockRuntimeAsyncClient.converseStream(request, handler); + StreamingHandler handler = StreamingHandlerFactory.createHandler(llmInterface, connector, httpClient, null); + handler.startStream(action, parameters, payload, actionListener); } catch (Exception e) { log.error("Failed to execute streaming", e); actionListener.onFailure(new MLException("Fail to execute streaming", e)); } } - private BedrockRuntimeAsyncClient buildBedrockRuntimeAsyncClient(SdkAsyncHttpClient sdkAsyncHttpClient) { - AwsCredentialsProvider awsCredentialsProvider; - if (connector.getSessionToken() != null) { - AwsSessionCredentials credentials = AwsSessionCredentials - .create(connector.getAccessKey(), connector.getSecretKey(), connector.getSessionToken()); - awsCredentialsProvider = StaticCredentialsProvider.create(credentials); - } else { - awsCredentialsProvider = StaticCredentialsProvider - .create( - software.amazon.awssdk.auth.credentials.AwsBasicCredentials.create(connector.getAccessKey(), connector.getSecretKey()) - ); - } - - return BedrockRuntimeAsyncClient - .builder() - .region(Region.of(connector.getRegion())) - .credentialsProvider(awsCredentialsProvider) - .httpClient(sdkAsyncHttpClient) - .build(); - } - private SdkHttpFullRequest signRequest(SdkHttpFullRequest request) { String accessKey = connector.getAccessKey(); String secretKey = connector.getSecretKey(); @@ -242,7 +177,7 @@ private void validateLLMInterface(String llmInterface) { case LLM_INTERFACE_BEDROCK_CONVERSE_CLAUDE: break; default: - throw new MLException(String.format("Unsupported llm interface: %s", llmInterface)); + throw new IllegalArgumentException(String.format("Unsupported llm interface: %s", llmInterface)); } } } diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtils.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtils.java index 7be072f970..cc0736c9e2 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtils.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtils.java @@ -386,7 +386,7 @@ public static SdkHttpFullRequest buildSdkRequest( return builder.build(); } - public static Request buildOKHttpRequestPOST(String action, Connector connector, Map parameters, String payload) { + public static Request buildOKHttpStreamingRequest(String action, Connector connector, Map parameters, String payload) { okhttp3.RequestBody requestBody; if (payload != null) { requestBody = okhttp3.RequestBody.create(payload, MediaType.parse("application/json; charset=utf-8")); @@ -395,6 +395,18 @@ public static Request buildOKHttpRequestPOST(String action, Connector connector, } String endpoint = connector.getActionEndpoint(action, parameters); + URI uri; + try { + uri = URI.create(endpoint); + if (uri.getHost() == null) { + throw new IllegalArgumentException("Invalid URI" + ". Please check if the endpoint is valid from connector."); + } + } catch (Exception e) { + throw new IllegalArgumentException( + "Encountered error when trying to create uri from endpoint in ml connector. Please update the endpoint in connection configuration: ", + e + ); + } Request.Builder requestBuilder = new Request.Builder(); Map headers = connector.getDecryptedHeaders(); if (headers != null) { @@ -402,6 +414,8 @@ public static Request buildOKHttpRequestPOST(String action, Connector connector, requestBuilder.addHeader(key, headers.get(key)); } } + + // Add SSE-specific headers requestBuilder.addHeader("Accept-Encoding", ""); requestBuilder.addHeader("Accept", "text/event-stream"); requestBuilder.addHeader("Cache-Control", "no-cache"); diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutor.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutor.java index 6a408ba61d..b8984a1246 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutor.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutor.java @@ -8,20 +8,16 @@ import static org.opensearch.ml.common.connector.ConnectorProtocols.HTTP; import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.LLM_INTERFACE_OPENAI_V1_CHAT_COMPLETIONS; import static org.opensearch.ml.engine.algorithms.agent.MLChatAgentRunner.LLM_INTERFACE; -import static org.opensearch.ml.engine.function_calling.OpenaiV1ChatCompletionsFunctionCalling.FINISH_REASON_PATH; import static software.amazon.awssdk.http.SdkHttpMethod.GET; import static software.amazon.awssdk.http.SdkHttpMethod.POST; -import java.io.IOException; import java.net.URL; import java.security.AccessController; -import java.security.PrivilegedActionException; import java.security.PrivilegedExceptionAction; import java.time.Duration; import java.util.Locale; import java.util.Map; import java.util.concurrent.CompletableFuture; -import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; import org.apache.commons.text.StringEscapeUtils; @@ -37,24 +33,17 @@ import org.opensearch.ml.common.model.MLGuard; import org.opensearch.ml.common.output.model.ModelTensors; import org.opensearch.ml.common.transport.MLTaskResponse; -import org.opensearch.ml.common.utils.StringUtils; +import org.opensearch.ml.engine.algorithms.remote.streaming.StreamPredictActionListener; +import org.opensearch.ml.engine.algorithms.remote.streaming.StreamingHandler; +import org.opensearch.ml.engine.algorithms.remote.streaming.StreamingHandlerFactory; import org.opensearch.ml.engine.annotation.ConnectorExecutor; import org.opensearch.script.ScriptService; import org.opensearch.transport.StreamTransportService; import org.opensearch.transport.client.Client; -import com.jayway.jsonpath.JsonPath; - import lombok.Getter; import lombok.Setter; import lombok.extern.log4j.Log4j2; -import okhttp3.OkHttpClient; -import okhttp3.Request; -import okhttp3.Response; -import okhttp3.internal.http2.StreamResetException; -import okhttp3.sse.EventSource; -import okhttp3.sse.EventSourceListener; -import okhttp3.sse.EventSources; import software.amazon.awssdk.core.internal.http.async.SimpleHttpContentPublisher; import software.amazon.awssdk.http.SdkHttpFullRequest; import software.amazon.awssdk.http.async.AsyncExecuteRequest; @@ -87,7 +76,6 @@ public class HttpJsonConnectorExecutor extends AbstractConnectorExecutor { private SdkAsyncHttpClient httpClient; - private OkHttpClient okHttpClient; @Setter @Getter private StreamTransportService streamTransportService; @@ -99,18 +87,6 @@ public HttpJsonConnectorExecutor(Connector connector) { Duration readTimeout = Duration.ofSeconds(super.getConnectorClientConfig().getReadTimeout()); Integer maxConnection = super.getConnectorClientConfig().getMaxConnections(); this.httpClient = MLHttpClientFactory.getAsyncHttpClient(connectionTimeout, readTimeout, maxConnection); - try { - AccessController.doPrivileged((PrivilegedExceptionAction) () -> { - this.okHttpClient = new OkHttpClient.Builder() - .connectTimeout(10, TimeUnit.SECONDS) - .readTimeout(1, TimeUnit.MINUTES) - .retryOnConnectionFailure(true) - .build(); - return null; - }); - } catch (PrivilegedActionException e) { - throw new RuntimeException("Failed to build OkHttpClient.", e); - } } @Override @@ -184,15 +160,9 @@ public void invokeRemoteServiceStream( llmInterface = StringEscapeUtils.unescapeJava(llmInterface); validateLLMInterface(llmInterface); - log.info("Creating SSE connection for streaming request"); - EventSourceListener listener = new HTTPEventSourceListener(actionListener, llmInterface); - Request request = ConnectorUtils.buildOKHttpRequestPOST(action, connector, parameters, payload); - - AccessController.doPrivileged((PrivilegedExceptionAction) () -> { - EventSources.createFactory(okHttpClient).newEventSource(request, listener); - return null; - }); - + StreamingHandler handler = StreamingHandlerFactory + .createHandler(llmInterface, connector, null, super.getConnectorClientConfig()); + handler.startStream(action, parameters, payload, actionListener); } catch (Exception e) { log.error("Failed to execute streaming", e); actionListener.onFailure(new MLException("Fail to execute streaming", e)); @@ -213,106 +183,7 @@ private void validateLLMInterface(String llmInterface) { case LLM_INTERFACE_OPENAI_V1_CHAT_COMPLETIONS: break; default: - throw new MLException(String.format("Unsupported llm interface: %s", llmInterface)); - } - } - - public final class HTTPEventSourceListener extends EventSourceListener { - private StreamPredictActionListener streamActionListener; - private final String llmInterface; - private volatile AtomicBoolean isStreamClosed; - - public HTTPEventSourceListener(StreamPredictActionListener streamActionListener, String llmInterface) { - this.streamActionListener = streamActionListener; - this.llmInterface = llmInterface; - this.isStreamClosed = new AtomicBoolean(false); - } - - /*** - * Callback when the SSE endpoint connection is made. - * @param eventSource the event source - * @param response the response - */ - @Override - public void onOpen(EventSource eventSource, Response response) { - log.debug("Connected to SSE Endpoint."); - } - - /*** - * For each event received from the SSE endpoint - * @param eventSource The event source - * @param id The id of the event - * @param type The type of the event which is used to filter - * @param data The event data - */ - @Override - public void onEvent(EventSource eventSource, String id, String type, String data) { - log.debug("The data is: {}", data); - switch (llmInterface) { - case LLM_INTERFACE_OPENAI_V1_CHAT_COMPLETIONS: - onOpenAIEvent(data); - break; - default: - throw new MLException(String.format("Unsupported llm interface: %s", llmInterface)); - } - } - - /*** - * When the connection is closed we receive this even which is currently only logged. - * @param eventSource The event source - */ - @Override - public void onClosed(EventSource eventSource) { - log.debug("SSE CLOSED."); - } - - /*** - * If there is any failure we log the error and the stack trace - * During stream resets with no errors we set the connected flag to false to allow the main thread to attempt a re-connect - * @param eventSource The event source - * @param t The error object - * @param response The response - */ - @Override - public void onFailure(EventSource eventSource, Throwable t, Response response) { - if (t != null) { - // Network/connection error - log.error("Error: " + t.getMessage(), t); - if (t instanceof StreamResetException && t.getMessage().contains("NO_ERROR")) { - // TODO: reconnect - } else { - streamActionListener.onFailure(new MLException("SSE failure with network error", t)); - } - } else if (response != null) { - // HTTP error (e.g., 400 Bad Request) - try { - String errorBody = response.body() != null ? response.body().string() : ""; - streamActionListener.onFailure(new MLException("Error from remote service: " + errorBody)); - } catch (IOException e) { - streamActionListener.onFailure(new MLException("SSE failure - unable to read error details")); - } - } else { - // Unknown failure - streamActionListener.onFailure(new MLException("SSE failure")); - } - } - - private void onOpenAIEvent(String data) { - if (data.contentEquals("[DONE]")) { - sendCompletionResponse(isStreamClosed, streamActionListener); - return; - } - Map dataMap = StringUtils.fromJson(data, "data"); - String llmFinishReason = JsonPath.read(dataMap, FINISH_REASON_PATH); - if (llmFinishReason != null && llmFinishReason.contentEquals("stop")) { - sendCompletionResponse(isStreamClosed, streamActionListener); - return; - } - String deltaContent = JsonPath.read(dataMap, "$.choices[0].delta.content"); - if (deltaContent != null && !deltaContent.isEmpty()) { - log.debug("Streaming content: {}", deltaContent); - sendContentResponse(deltaContent, false, streamActionListener); - } + throw new IllegalArgumentException(String.format("Unsupported llm interface: %s", llmInterface)); } } } diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/McpConnectorExecutor.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/McpConnectorExecutor.java index bf7930265b..181296e335 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/McpConnectorExecutor.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/McpConnectorExecutor.java @@ -36,6 +36,7 @@ import org.opensearch.ml.common.output.model.ModelTensors; import org.opensearch.ml.common.transport.MLTaskResponse; import org.opensearch.ml.common.utils.StringUtils; +import org.opensearch.ml.engine.algorithms.remote.streaming.StreamPredictActionListener; import org.opensearch.ml.engine.annotation.ConnectorExecutor; import org.opensearch.ml.engine.tools.McpSseTool; import org.opensearch.script.ScriptService; diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/McpStreamableHttpConnectorExecutor.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/McpStreamableHttpConnectorExecutor.java index 1fef7612b9..d73e294785 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/McpStreamableHttpConnectorExecutor.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/McpStreamableHttpConnectorExecutor.java @@ -38,6 +38,7 @@ import org.opensearch.ml.common.output.model.ModelTensors; import org.opensearch.ml.common.transport.MLTaskResponse; import org.opensearch.ml.common.utils.StringUtils; +import org.opensearch.ml.engine.algorithms.remote.streaming.StreamPredictActionListener; import org.opensearch.ml.engine.annotation.ConnectorExecutor; import org.opensearch.ml.engine.tools.McpStreamableHttpTool; import org.opensearch.script.ScriptService; diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteConnectorExecutor.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteConnectorExecutor.java index 9b194328bb..49b25c16d7 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteConnectorExecutor.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/RemoteConnectorExecutor.java @@ -53,6 +53,7 @@ import org.opensearch.ml.common.output.model.ModelTensors; import org.opensearch.ml.common.transport.MLTaskResponse; import org.opensearch.ml.common.utils.StringUtils; +import org.opensearch.ml.engine.algorithms.remote.streaming.StreamPredictActionListener; import org.opensearch.script.ScriptService; import org.opensearch.threadpool.ThreadPool; import org.opensearch.transport.TransportChannel; @@ -71,17 +72,12 @@ default void executeAction(String action, MLInput mlInput, ActionListener actionListener, TransportChannel channel) { // Check for streaming if (channel != null) { - preparePayloadAndInvoke(action, mlInput, new ExecutionContext(0), new ActionListener>() { - @Override - public void onResponse(Tuple response) { - actionListener.onResponse(new MLTaskResponse(new ModelTensorOutput(Arrays.asList(response.v2())))); - } - - @Override - public void onFailure(Exception e) { - actionListener.onFailure(e); - } - }, channel); + ActionListener> streamingListener = ActionListener.wrap(response -> { + ModelTensors tensors = response.v2(); + MLTaskResponse mlResponse = new MLTaskResponse(new ModelTensorOutput(Arrays.asList(tensors))); + actionListener.onResponse(mlResponse); + }, actionListener::onFailure); + preparePayloadAndInvoke(action, mlInput, new ExecutionContext(0), streamingListener, actionListener, channel); return; } @@ -115,18 +111,11 @@ public void onFailure(Exception e) { .inputDataset(TextDocsInputDataSet.builder().docs(textDocs).build()) .build(), new ExecutionContext(sequence++), - groupedActionListener, - null + groupedActionListener ); } } else { - preparePayloadAndInvoke( - action, - mlInput, - new ExecutionContext(0), - new GroupedActionListener<>(tensorActionListener, 1), - null - ); + preparePayloadAndInvoke(action, mlInput, new ExecutionContext(0), new GroupedActionListener<>(tensorActionListener, 1)); } } catch (Exception e) { actionListener.onFailure(e); @@ -204,11 +193,21 @@ default void setUserRateLimiterMap(Map userRateLimiterMap) default void setMlGuard(MLGuard mlGuard) {} + default void preparePayloadAndInvoke( + String action, + MLInput mlInput, + ExecutionContext executionContext, + ActionListener> actionListener + ) { + preparePayloadAndInvoke(action, mlInput, executionContext, actionListener, null, null); + } + default void preparePayloadAndInvoke( String action, MLInput mlInput, ExecutionContext executionContext, ActionListener> actionListener, + ActionListener agentListener, TransportChannel channel ) { Connector connector = getConnector(); @@ -271,7 +270,16 @@ && getUserRateLimiterMap().get(user.getName()) != null if (getConnectorClientConfig().getMaxRetryTimes() != 0) { invokeRemoteServiceWithRetry(action, mlInput, parameters, payload, executionContext, actionListener); } else if (parameters.containsKey("stream")) { - StreamPredictActionListener streamListener = new StreamPredictActionListener<>(channel); + String memoryId = parameters.get("memory_id"); + String parentInteractionId = parameters.get("parent_interaction_id"); + // TODO: find a better way to differentiate agent and predict request + boolean isAgentRequest = (memoryId != null || parentInteractionId != null); + StreamPredictActionListener streamListener = new StreamPredictActionListener<>( + channel, + isAgentRequest ? agentListener : null, + memoryId, + parentInteractionId + ); invokeRemoteServiceStream(action, mlInput, parameters, payload, executionContext, streamListener); } else { invokeRemoteService(action, mlInput, parameters, payload, executionContext, actionListener); diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/streaming/BaseStreamingHandler.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/streaming/BaseStreamingHandler.java new file mode 100644 index 0000000000..73a4f740ed --- /dev/null +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/streaming/BaseStreamingHandler.java @@ -0,0 +1,38 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.engine.algorithms.remote.streaming; + +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.concurrent.atomic.AtomicBoolean; + +import org.opensearch.ml.common.output.model.ModelTensor; +import org.opensearch.ml.common.output.model.ModelTensorOutput; +import org.opensearch.ml.common.output.model.ModelTensors; +import org.opensearch.ml.common.transport.MLTaskResponse; + +public abstract class BaseStreamingHandler implements StreamingHandler { + + protected void sendContentResponse(String content, boolean isLast, StreamPredictActionListener actionListener) { + List modelTensors = new ArrayList<>(); + Map dataMap = Map.of("content", content, "is_last", isLast); + + modelTensors.add(ModelTensor.builder().name("response").dataAsMap(dataMap).build()); + ModelTensorOutput output = ModelTensorOutput + .builder() + .mlModelOutputs(List.of(ModelTensors.builder().mlModelTensors(modelTensors).build())) + .build(); + MLTaskResponse response = MLTaskResponse.builder().output(output).build(); + actionListener.onStreamResponse(response, isLast); + } + + protected void sendCompletionResponse(AtomicBoolean isStreamClosed, StreamPredictActionListener actionListener) { + if (isStreamClosed.compareAndSet(false, true)) { + sendContentResponse("", true, actionListener); + } + } +} diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/streaming/BedrockStreamingHandler.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/streaming/BedrockStreamingHandler.java new file mode 100644 index 0000000000..0ec9cce537 --- /dev/null +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/streaming/BedrockStreamingHandler.java @@ -0,0 +1,510 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.engine.algorithms.remote.streaming; + +import static org.opensearch.ml.common.CommonValue.REMOTE_SERVICE_ERROR; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicReference; +import java.util.stream.Collectors; + +import javax.naming.AuthenticationException; + +import org.opensearch.OpenSearchStatusException; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.ml.common.connector.AwsConnector; +import org.opensearch.ml.common.exception.MLException; +import org.opensearch.ml.common.output.model.ModelTensor; +import org.opensearch.ml.common.output.model.ModelTensorOutput; +import org.opensearch.ml.common.output.model.ModelTensors; +import org.opensearch.ml.common.transport.MLTaskResponse; +import org.opensearch.ml.engine.algorithms.remote.RemoteConnectorThrottlingException; + +import com.fasterxml.jackson.core.JsonParseException; +import com.fasterxml.jackson.core.JsonParser; +import com.fasterxml.jackson.core.JsonToken; +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; + +import lombok.extern.log4j.Log4j2; +import software.amazon.awssdk.auth.credentials.AwsBasicCredentials; +import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; +import software.amazon.awssdk.auth.credentials.AwsSessionCredentials; +import software.amazon.awssdk.auth.credentials.StaticCredentialsProvider; +import software.amazon.awssdk.core.document.Document; +import software.amazon.awssdk.http.async.SdkAsyncHttpClient; +import software.amazon.awssdk.regions.Region; +import software.amazon.awssdk.services.bedrockruntime.BedrockRuntimeAsyncClient; +import software.amazon.awssdk.services.bedrockruntime.model.ContentBlock; +import software.amazon.awssdk.services.bedrockruntime.model.ContentBlockDeltaEvent; +import software.amazon.awssdk.services.bedrockruntime.model.ContentBlockStartEvent; +import software.amazon.awssdk.services.bedrockruntime.model.ConverseStreamOutput; +import software.amazon.awssdk.services.bedrockruntime.model.ConverseStreamRequest; +import software.amazon.awssdk.services.bedrockruntime.model.ConverseStreamResponseHandler; +import software.amazon.awssdk.services.bedrockruntime.model.Message; +import software.amazon.awssdk.services.bedrockruntime.model.SystemContentBlock; +import software.amazon.awssdk.services.bedrockruntime.model.Tool; +import software.amazon.awssdk.services.bedrockruntime.model.ToolConfiguration; +import software.amazon.awssdk.services.bedrockruntime.model.ToolInputSchema; +import software.amazon.awssdk.services.bedrockruntime.model.ToolResultBlock; +import software.amazon.awssdk.services.bedrockruntime.model.ToolResultContentBlock; +import software.amazon.awssdk.services.bedrockruntime.model.ToolSpecification; +import software.amazon.awssdk.services.bedrockruntime.model.ValidationException; +import software.amazon.awssdk.services.s3.model.InvalidRequestException; + +@Log4j2 +public class BedrockStreamingHandler extends BaseStreamingHandler { + + private final SdkAsyncHttpClient httpClient; + private final AwsConnector connector; + private static final String STOP_REASON_TOOL_USE = "StopReason=tool_use"; + + private enum StreamState { + STREAMING_CONTENT, + TOOL_CALL_DETECTED, + ACCUMULATING_TOOL_INPUT, + WAITING_FOR_TOOL_RESULT, + COMPLETED + } + + public BedrockStreamingHandler(SdkAsyncHttpClient httpClient, AwsConnector connector) { + this.httpClient = httpClient; + this.connector = connector; + } + + @Override + public void startStream( + String action, + Map parameters, + String payload, + StreamPredictActionListener listener + ) { + try { + AtomicBoolean isStreamClosed = new AtomicBoolean(false); + AtomicReference toolName = new AtomicReference<>(); + AtomicReference> toolInput = new AtomicReference<>(); + AtomicReference toolUseId = new AtomicReference<>(); + StringBuilder toolInputAccumulator = new StringBuilder(); + AtomicReference currentState = new AtomicReference<>(StreamState.STREAMING_CONTENT); + + // Build Bedrock client + BedrockRuntimeAsyncClient bedrockClient = buildBedrockRuntimeAsyncClient(); + + // Parse payload to build ConverseStreamRequest + ConverseStreamRequest request = buildConverseStreamRequest(payload, parameters); + + ConverseStreamResponseHandler handler = ConverseStreamResponseHandler.builder().onResponse(response -> {}).onError(error -> { + log.error("Converse stream error: {}", error.getMessage()); + if (isThrottlingError(error)) { + listener + .onFailure( + new RemoteConnectorThrottlingException( + REMOTE_SERVICE_ERROR + + "The request was denied due to remote server throttling. " + + "To change the retry policy and behavior, please update the connector client_config.", + RestStatus.BAD_REQUEST + ) + ); + } else if (isClientError(error)) { + // 4XX errors + listener.onFailure(new OpenSearchStatusException(REMOTE_SERVICE_ERROR + error.getMessage(), RestStatus.BAD_REQUEST)); + } else { + // 5xx errors + listener.onFailure(new MLException(REMOTE_SERVICE_ERROR + error.getMessage(), error)); + } + }).onComplete(() -> { + if (currentState.get() != StreamState.WAITING_FOR_TOOL_RESULT) { + sendCompletionResponse(isStreamClosed, listener); + } else { + log.debug("Tool execution in progress - keeping stream open"); + } + }).subscriber(event -> { + handleStreamEvent(event, listener, isStreamClosed, toolName, toolInput, toolUseId, toolInputAccumulator, currentState); + }).build(); + + // Start streaming + bedrockClient.converseStream(request, handler); + } catch (Exception e) { + log.error("Failed to execute Bedrock streaming", e); + handleError(e, listener); + } + } + + @Override + public void handleError(Throwable error, StreamPredictActionListener listener) { + log.error("HTTP streaming error", error); + listener.onFailure(new MLException("Fail to execute streaming", error)); + } + + private boolean isThrottlingError(Throwable error) { + return error.getMessage().contains("throttling") + || error.getMessage().contains("TooManyRequestsException") + || error.getMessage().contains("Rate exceeded"); + } + + private boolean isClientError(Throwable error) { + return error instanceof ValidationException || error instanceof InvalidRequestException || error instanceof AuthenticationException; + } + + private ConverseStreamRequest buildConverseStreamRequest(String payload, Map parameters) { + try { + ObjectMapper mapper = new ObjectMapper(); + JsonNode payloadJson = mapper.readTree(payload); + return ConverseStreamRequest + .builder() + .modelId(parameters.get("model")) + .system(getOptionalNode(payloadJson, "system").map(this::parseSystemMessages).orElse(null)) + .messages(getOptionalNode(payloadJson, "messages").map(this::parseMessages).orElse(null)) + .toolConfig(getOptionalNode(payloadJson, "toolConfig").map(this::parseToolConfig).orElse(null)) + .build(); + } catch (Exception e) { + throw new MLException("Failed to parse payload for Bedrock request", e); + } + } + + private Optional getOptionalNode(JsonNode json, String field) { + return Optional.ofNullable(json.get(field)); + } + + private void handleStreamEvent( + ConverseStreamOutput event, + StreamPredictActionListener listener, + AtomicBoolean isStreamClosed, + AtomicReference toolName, + AtomicReference> toolInput, + AtomicReference toolUseId, + StringBuilder toolInputAccumulator, + AtomicReference currentState + ) { + switch (currentState.get()) { + case STREAMING_CONTENT: + if (isToolUseDetected(event)) { + currentState.set(StreamState.TOOL_CALL_DETECTED); + extractToolInfo(event, toolName, toolUseId); + } else if (isContentDelta(event)) { + sendContentResponse(getTextContent(event), false, listener); + } else if (isStreamComplete(event)) { + currentState.set(StreamState.COMPLETED); + sendCompletionResponse(isStreamClosed, listener); + } + break; + + case TOOL_CALL_DETECTED: + if (isToolInputDelta(event)) { + currentState.set(StreamState.ACCUMULATING_TOOL_INPUT); + accumulateToolInput(getToolInputFragment(event), toolInput, toolInputAccumulator); + } + break; + + case ACCUMULATING_TOOL_INPUT: + if (isToolInputDelta(event)) { + accumulateToolInput(getToolInputFragment(event), toolInput, toolInputAccumulator); + } else if (isToolInputComplete(event)) { + currentState.set(StreamState.WAITING_FOR_TOOL_RESULT); + listener.onResponse(createToolUseResponse(toolName, toolInput, toolUseId)); + } + break; + + case WAITING_FOR_TOOL_RESULT: + // Don't close stream - wait for tool execution + log.debug("Waiting for tool result - keeping stream open"); + break; + + case COMPLETED: + // Stream already completed + break; + } + } + + // TODO: refactor the event type checker methods + private void extractToolInfo(ConverseStreamOutput event, AtomicReference toolName, AtomicReference toolUseId) { + ContentBlockStartEvent startEvent = (ContentBlockStartEvent) event; + if (startEvent.start() != null && startEvent.start().toolUse() != null) { + toolName.set(startEvent.start().toolUse().name()); + toolUseId.set(startEvent.start().toolUse().toolUseId()); + } + } + + private String getTextContent(ConverseStreamOutput event) { + ContentBlockDeltaEvent contentEvent = (ContentBlockDeltaEvent) event; + return contentEvent.delta().text(); + } + + private String getToolInputFragment(ConverseStreamOutput event) { + ContentBlockDeltaEvent contentEvent = (ContentBlockDeltaEvent) event; + return contentEvent.delta().toolUse().input(); + } + + private boolean isToolUseDetected(ConverseStreamOutput event) { + return event.sdkEventType() == ConverseStreamOutput.EventType.CONTENT_BLOCK_START; + } + + private boolean isContentDelta(ConverseStreamOutput event) { + return event.sdkEventType() == ConverseStreamOutput.EventType.CONTENT_BLOCK_DELTA + && ((ContentBlockDeltaEvent) event).delta().text() != null; + } + + private boolean isToolInputDelta(ConverseStreamOutput event) { + return event.sdkEventType() == ConverseStreamOutput.EventType.CONTENT_BLOCK_DELTA + && ((ContentBlockDeltaEvent) event).delta().toolUse() != null; + } + + private boolean isStreamComplete(ConverseStreamOutput event) { + return event.sdkEventType() == ConverseStreamOutput.EventType.MESSAGE_STOP && !event.toString().contains(STOP_REASON_TOOL_USE); + } + + private boolean isToolInputComplete(ConverseStreamOutput event) { + return event.sdkEventType() == ConverseStreamOutput.EventType.MESSAGE_STOP && event.toString().contains(STOP_REASON_TOOL_USE); + } + + private MLTaskResponse createToolUseResponse( + AtomicReference toolName, + AtomicReference> toolInput, + AtomicReference toolUseId + ) { + // Validate inputs + if (toolName == null || toolInput == null || toolUseId == null) { + throw new IllegalArgumentException("Tool references cannot be null"); + } + Map wrappedResponse = Map + .of( + "output", + Map + .of( + "message", + Map + .of( + "content", + List + .of( + Map + .of( + "toolUse", + Map.of("name", toolName.get(), "input", toolInput.get(), "toolUseId", toolUseId.get()) + ) + ) + ) + ), + "stopReason", + "tool_use" + ); + + ModelTensor tensor = ModelTensor.builder().name("response").dataAsMap(wrappedResponse).build(); + ModelTensors tensors = ModelTensors.builder().mlModelTensors(List.of(tensor)).build(); + ModelTensorOutput output = ModelTensorOutput.builder().mlModelOutputs(List.of(tensors)).build(); + return new MLTaskResponse(output); + } + + private void accumulateToolInput( + String inputFragment, + AtomicReference> toolInput, + StringBuilder toolInputAccumulator + ) { + if (inputFragment == null) { + return; + } + ObjectMapper objectMapper = new ObjectMapper(); + toolInputAccumulator.append(inputFragment); + String accumulated = toolInputAccumulator.toString(); + + try { + JsonParser parser = objectMapper.getFactory().createParser(accumulated); + JsonToken firstToken = parser.nextToken(); + + // Check if it starts with an object + if (firstToken != JsonToken.START_OBJECT) { + log.debug("Input does not start with an object: {}", accumulated); + return; + } + + // Parse through the entire structure + int objectDepth = 1; + while (parser.nextToken() != null) { + JsonToken currentToken = parser.getCurrentToken(); + if (currentToken == JsonToken.START_OBJECT) { + objectDepth++; + } else if (currentToken == JsonToken.END_OBJECT) { + objectDepth--; + } + + // Check if a complete object is found + if (objectDepth == 0) { + // Check if there's any remaining content + if (parser.nextToken() != null) { + log.debug("Extra content after JSON object: {}", accumulated); + return; + } + + // Valid and complete JSON object found + Map parsedInput = objectMapper.readValue(accumulated, Map.class); + toolInput.set(parsedInput); + log.debug("Successfully parsed tool input: {}", parsedInput); + return; + } + } + + // JSON is incomplete + log.debug("Incomplete JSON object: {}", accumulated); + + } catch (JsonParseException e) { + log.debug("Invalid or incomplete JSON: {}", accumulated); + } catch (IOException e) { + log.error("Error parsing JSON input", e); + } + } + + private BedrockRuntimeAsyncClient buildBedrockRuntimeAsyncClient() { + AwsCredentialsProvider awsCredentialsProvider = connector.getSessionToken() != null + ? StaticCredentialsProvider + .create(AwsSessionCredentials.create(connector.getAccessKey(), connector.getSecretKey(), connector.getSessionToken())) + : StaticCredentialsProvider.create(AwsBasicCredentials.create(connector.getAccessKey(), connector.getSecretKey())); + + return BedrockRuntimeAsyncClient + .builder() + .region(Region.of(connector.getRegion())) + .credentialsProvider(awsCredentialsProvider) + .httpClient(httpClient) + .build(); + } + + private List parseSystemMessages(JsonNode systemArray) { + return systemArray + .findValuesAsText("text") + .stream() + .map(text -> SystemContentBlock.builder().text(text).build()) + .collect(Collectors.toList()); + } + + private List parseMessages(JsonNode messagesArray) { + List messages = new ArrayList<>(); + for (JsonNode messageItem : messagesArray) { + messages.add(buildMessage(messageItem)); + } + return messages; + } + + private Message buildMessage(JsonNode messageItem) { + String role = messageItem.has("role") && messageItem.get("role") != null ? messageItem.get("role").asText() : "assistant"; + + List contentBlocks = buildContentBlocks(messageItem.get("content")); + return Message.builder().role(role).content(contentBlocks).build(); + } + + private List buildContentBlocks(JsonNode contentArray) { + List blocks = new ArrayList<>(); + if (contentArray != null && contentArray.isArray()) { + for (JsonNode item : contentArray) { + addContentBlock(blocks, item); + } + } + return blocks; + } + + private void addContentBlock(List blocks, JsonNode item) { + if (item.has("text")) { + blocks.add(ContentBlock.builder().text(item.get("text").asText()).build()); + } + if (item.has("toolResult")) { + blocks.add(buildToolResultBlock(item.get("toolResult"))); + } + if (item.has("toolUse")) { + blocks.add(buildToolUseBlock(item.get("toolUse"))); + } + } + + private ContentBlock buildToolResultBlock(JsonNode toolResult) { + String text = extractResultText(toolResult.get("content")); + return ContentBlock + .builder() + .toolResult( + ToolResultBlock + .builder() + .toolUseId(toolResult.get("toolUseId").asText()) + .content(ToolResultContentBlock.builder().text(text).build()) + .build() + ) + .build(); + } + + private String extractResultText(JsonNode content) { + if (content.isArray() && content.size() > 0) { + return content.get(0).get("text").asText(); + } + return content.isTextual() ? content.asText() : ""; + } + + private ContentBlock buildToolUseBlock(JsonNode toolUse) { + Document input = toolUse.has("input") ? buildDocumentFromJsonNode(toolUse.get("input")) : Document.fromMap(Map.of()); + + return ContentBlock + .builder() + .toolUse( + software.amazon.awssdk.services.bedrockruntime.model.ToolUseBlock + .builder() + .toolUseId(toolUse.get("toolUseId").asText()) + .name(toolUse.get("name").asText()) + .input(input) + .build() + ) + .build(); + } + + private ToolConfiguration parseToolConfig(JsonNode toolConfig) { + if (!toolConfig.has("tools")) + return null; + + List tools = new ArrayList<>(); + for (JsonNode toolItem : toolConfig.get("tools")) { + if (toolItem.has("toolSpec")) { + tools.add(buildTool(toolItem.get("toolSpec"))); + } + } + return ToolConfiguration.builder().tools(tools).build(); + } + + private Tool buildTool(JsonNode toolSpec) { + Document schema = buildDocumentFromJsonNode(toolSpec.get("inputSchema").get("json")); + return Tool + .builder() + .toolSpec( + ToolSpecification + .builder() + .name(toolSpec.get("name").asText()) + .description(toolSpec.get("description").asText()) + .inputSchema(ToolInputSchema.builder().json(schema).build()) + .build() + ) + .build(); + } + + private Document buildDocumentFromJsonNode(JsonNode node) { + if (node.isObject()) { + Map map = new HashMap<>(); + node.fields().forEachRemaining(entry -> map.put(entry.getKey(), buildDocumentFromJsonNode(entry.getValue()))); + return Document.fromMap(map); + } + if (node.isArray()) { + List list = new ArrayList<>(); + for (JsonNode item : node) { + list.add(buildDocumentFromJsonNode(item)); + } + return Document.fromList(list); + } + if (node.isTextual()) + return Document.fromString(node.asText()); + if (node.isBoolean()) + return Document.fromBoolean(node.asBoolean()); + if (node.isNumber()) + return Document.fromNumber(node.isInt() ? node.asInt() : node.asDouble()); + return Document.fromString(node.toString()); + } +} diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/streaming/HttpStreamingHandler.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/streaming/HttpStreamingHandler.java new file mode 100644 index 0000000000..15078dfccb --- /dev/null +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/streaming/HttpStreamingHandler.java @@ -0,0 +1,292 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.engine.algorithms.remote.streaming; + +import static org.opensearch.ml.common.utils.StringUtils.gson; +import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.LLM_INTERFACE_OPENAI_V1_CHAT_COMPLETIONS; + +import java.io.IOException; +import java.security.AccessController; +import java.security.PrivilegedExceptionAction; +import java.time.Duration; +import java.util.List; +import java.util.Map; +import java.util.concurrent.atomic.AtomicBoolean; + +import org.opensearch.ml.common.connector.Connector; +import org.opensearch.ml.common.connector.ConnectorClientConfig; +import org.opensearch.ml.common.exception.MLException; +import org.opensearch.ml.common.output.model.ModelTensor; +import org.opensearch.ml.common.output.model.ModelTensorOutput; +import org.opensearch.ml.common.output.model.ModelTensors; +import org.opensearch.ml.common.transport.MLTaskResponse; +import org.opensearch.ml.common.utils.StringUtils; +import org.opensearch.ml.engine.algorithms.remote.ConnectorUtils; + +import com.jayway.jsonpath.JsonPath; + +import lombok.extern.log4j.Log4j2; +import okhttp3.OkHttpClient; +import okhttp3.Request; +import okhttp3.Response; +import okhttp3.internal.http2.StreamResetException; +import okhttp3.sse.EventSource; +import okhttp3.sse.EventSourceListener; +import okhttp3.sse.EventSources; + +@Log4j2 +public class HttpStreamingHandler extends BaseStreamingHandler { + + private final Connector connector; + private OkHttpClient okHttpClient; + private String llmInterface; + + public HttpStreamingHandler(String llmInterface, Connector connector, ConnectorClientConfig connectorClientConfig) { + this.connector = connector; + this.llmInterface = llmInterface; + + // Get connector client configuration + Duration connectionTimeout = Duration.ofSeconds(connectorClientConfig.getConnectionTimeout()); + Duration readTimeout = Duration.ofSeconds(connectorClientConfig.getReadTimeout()); + + // Initialize OkHttp client for SSE + try { + AccessController.doPrivileged((PrivilegedExceptionAction) () -> { + this.okHttpClient = new OkHttpClient.Builder() + .connectTimeout(connectionTimeout) + .readTimeout(readTimeout) + .retryOnConnectionFailure(true) + .build(); + return null; + }); + } catch (Exception e) { + throw new RuntimeException("Failed to build OkHttpClient", e); + } + } + + @Override + public void startStream( + String action, + Map parameters, + String payload, + StreamPredictActionListener actionListener + ) { + try { + log.info("Creating SSE connection for streaming request"); + EventSourceListener listener = new HTTPEventSourceListener(actionListener, llmInterface); + Request request = ConnectorUtils.buildOKHttpStreamingRequest(action, connector, parameters, payload); + + AccessController.doPrivileged((PrivilegedExceptionAction) () -> { + EventSources.createFactory(okHttpClient).newEventSource(request, listener); + return null; + }); + + } catch (Exception e) { + log.error("Failed to start HTTP streaming", e); + handleError(e, actionListener); + } + } + + @Override + public void handleError(Throwable error, StreamPredictActionListener listener) { + log.error("HTTP streaming error", error); + listener.onFailure(new MLException("Fail to execute streaming", error)); + } + + public final class HTTPEventSourceListener extends EventSourceListener { + private StreamPredictActionListener streamActionListener; + private final String llmInterface; + private AtomicBoolean isStreamClosed; + private boolean functionCallInProgress = false; + private boolean agentExecutionInProgress = false; + private String accumulatedToolCallId = null; + private String accumulatedToolName = null; + private String accumulatedArguments = ""; + + public HTTPEventSourceListener(StreamPredictActionListener streamActionListener, String llmInterface) { + this.streamActionListener = streamActionListener; + this.llmInterface = llmInterface; + this.isStreamClosed = new AtomicBoolean(false); + } + + /*** + * Callback when the SSE endpoint connection is made. + * @param eventSource the event source + * @param response the response + */ + @Override + public void onOpen(EventSource eventSource, Response response) { + log.debug("Connected to SSE Endpoint."); + } + + /*** + * For each event received from the SSE endpoint + * @param eventSource The event source + * @param id The id of the event + * @param type The type of the event which is used to filter + * @param data The event data + */ + @Override + public void onEvent(EventSource eventSource, String id, String type, String data) { + log.debug("The data is: {}", data); + switch (llmInterface) { + case LLM_INTERFACE_OPENAI_V1_CHAT_COMPLETIONS: + onOpenAIEvent(data); + break; + default: + throw new IllegalArgumentException(String.format("Unsupported llm interface: %s", llmInterface)); + } + } + + /*** + * When the connection is closed we receive this even which is currently only logged. + * @param eventSource The event source + */ + @Override + public void onClosed(EventSource eventSource) { + log.debug("SSE CLOSED."); + } + + /*** + * If there is any failure we log the error and the stack trace + * During stream resets with no errors we set the connected flag to false to allow the main thread to attempt a re-connect + * @param eventSource The event source + * @param t The error object + * @param response The response + */ + @Override + public void onFailure(EventSource eventSource, Throwable t, Response response) { + if (t != null) { + // Network/connection error + log.error("Error: " + t.getMessage(), t); + if (t instanceof StreamResetException && t.getMessage().contains("NO_ERROR")) { + // TODO: reconnect + } else { + streamActionListener.onFailure(new MLException("SSE failure with network error", t)); + } + } else if (response != null) { + // HTTP error (e.g., 400 Bad Request) + try { + String errorBody = response.body() != null ? response.body().string() : ""; + streamActionListener.onFailure(new MLException("Error from remote service: " + errorBody)); + } catch (IOException e) { + streamActionListener.onFailure(new MLException("SSE failure - unable to read error details")); + } + } else { + // Unknown failure + streamActionListener.onFailure(new MLException("SSE failure")); + } + } + + private void onOpenAIEvent(String data) { + if ("[DONE]".equals(data)) { + handleDoneEvent(); + return; + } + + // Process stream chunk + try { + Map dataMap = gson.fromJson(data, Map.class); + processStreamChunk(dataMap); + } catch (Exception e) { + log.debug("Skipping malformed chunk: {}", data); + } + } + + private void handleDoneEvent() { + if (!agentExecutionInProgress) { + sendCompletionResponse(isStreamClosed, streamActionListener); + } + } + + private void processStreamChunk(Map dataMap) { + // Handle stop finish reason + String finishReason = extractPath(dataMap, "$.choices[0].finish_reason"); + if ("stop".equals(finishReason)) { + agentExecutionInProgress = false; + sendCompletionResponse(isStreamClosed, streamActionListener); + return; + } + + // Process content + String content = extractPath(dataMap, "$.choices[0].delta.content"); + if (content != null && !content.isEmpty()) { + sendContentResponse(content, false, streamActionListener); + } + + // Process tool call + List toolCalls = extractPath(dataMap, "$.choices[0].delta.tool_calls"); + if (toolCalls != null) { + accumulateFunctionCall(toolCalls); + sendContentResponse(StringUtils.toJson(toolCalls), false, streamActionListener); + } + + // Handle tool_calls finish reason + if ("tool_calls".equals(finishReason) && functionCallInProgress) { + completeToolCall(); + } + } + + private T extractPath(Map dataMap, String path) { + try { + return JsonPath.read(dataMap, path); + } catch (Exception e) { + return null; + } + } + + private void completeToolCall() { + agentExecutionInProgress = true; + String completeFunctionCall = buildCompleteFunctionCallResponse(); + + // Send to client and agent + sendContentResponse(completeFunctionCall, false, streamActionListener); + Map response = gson.fromJson(completeFunctionCall, Map.class); + ModelTensorOutput output = createModelTensorOutput(response); + streamActionListener.onResponse(new MLTaskResponse(output)); + + // Reset state + functionCallInProgress = false; + } + + private String buildCompleteFunctionCallResponse() { + Map function = Map.of("name", accumulatedToolName, "arguments", accumulatedArguments); + Map toolCall = Map.of("id", accumulatedToolCallId, "type", "function", "function", function); + Map message = Map.of("tool_calls", List.of(toolCall)); + Map choice = Map.of("message", message, "finish_reason", "tool_calls"); + Map response = Map.of("choices", List.of(choice)); + + return StringUtils.toJson(response); + } + + private ModelTensorOutput createModelTensorOutput(Map responseData) { + ModelTensor tensor = ModelTensor.builder().name("response").dataAsMap(responseData).build(); + ModelTensors tensors = ModelTensors.builder().mlModelTensors(List.of(tensor)).build(); + return ModelTensorOutput.builder().mlModelOutputs(List.of(tensors)).build(); + } + + private void accumulateFunctionCall(List toolCalls) { + functionCallInProgress = true; + for (Object toolCall : toolCalls) { + Map tcMap = (Map) toolCall; + + // Extract ID and name from first chunk + if (tcMap.containsKey("id")) { + accumulatedToolCallId = (String) tcMap.get("id"); + } + if (tcMap.containsKey("function")) { + Map func = (Map) tcMap.get("function"); + if (func.containsKey("name")) { + accumulatedToolName = (String) func.get("name"); + } + if (func.containsKey("arguments")) { + accumulatedArguments += (String) func.get("arguments"); + } + } + } + } + } +} diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/StreamPredictActionListener.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/streaming/StreamPredictActionListener.java similarity index 55% rename from ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/StreamPredictActionListener.java rename to ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/streaming/StreamPredictActionListener.java index 5c04336937..166b3871df 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/StreamPredictActionListener.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/streaming/StreamPredictActionListener.java @@ -3,8 +3,9 @@ * SPDX-License-Identifier: Apache-2.0 */ -package org.opensearch.ml.engine.algorithms.remote; +package org.opensearch.ml.engine.algorithms.remote.streaming; +import java.util.ArrayList; import java.util.LinkedHashMap; import java.util.List; import java.util.Map; @@ -26,9 +27,24 @@ public class StreamPredictActionListener { private final TransportChannel channel; + private final ActionListener agentListener; + private final String memoryId; + private final String parentInteractionId; public StreamPredictActionListener(TransportChannel channel) { + this(channel, null, null, null); + } + + public StreamPredictActionListener( + TransportChannel channel, + ActionListener agentListener, + String memoryId, + String parentInteractionId + ) { this.channel = channel; + this.agentListener = agentListener; + this.memoryId = memoryId; + this.parentInteractionId = parentInteractionId; } /** @@ -40,7 +56,11 @@ public StreamPredictActionListener(TransportChannel channel) { */ public void onStreamResponse(Response response, boolean isLastBatch) { assert response != null; - channel.sendResponseBatch(response); + + // Add metadata to all responses + Response responseWithMetadata = addMetadataToResponse(response); + + channel.sendResponseBatch(responseWithMetadata); if (isLastBatch) { channel.completeStream(); } @@ -54,7 +74,11 @@ public void onStreamResponse(Response response, boolean isLastBatch) { */ @Override public final void onResponse(Response response) { - onStreamResponse(response, true); + onStreamResponse(response, false); + + if (agentListener != null) { + agentListener.onResponse(response); + } } @Override @@ -72,6 +96,38 @@ public void onFailure(Exception e) { } } + private Response addMetadataToResponse(Response response) { + if (!(response instanceof MLTaskResponse)) { + return response; + } + + // Only add metadata for agent streaming + if (agentListener == null) { + return response; + } + + MLTaskResponse mlResponse = (MLTaskResponse) response; + if (mlResponse.getOutput() instanceof ModelTensorOutput) { + ModelTensorOutput output = (ModelTensorOutput) mlResponse.getOutput(); + List updatedOutputs = new ArrayList<>(); + + // TODO: refactor this to handle other types of agents + for (ModelTensors tensors : output.getMlModelOutputs()) { + List updatedTensors = new ArrayList<>(); + + updatedTensors.add(ModelTensor.builder().name("memory_id").result(memoryId).build()); + updatedTensors.add(ModelTensor.builder().name("parent_interaction_id").result(parentInteractionId).build()); + + updatedTensors.addAll(tensors.getMlModelTensors()); + updatedOutputs.add(ModelTensors.builder().mlModelTensors(updatedTensors).build()); + } + + ModelTensorOutput updatedOutput = ModelTensorOutput.builder().mlModelOutputs(updatedOutputs).build(); + return (Response) new MLTaskResponse(updatedOutput); + } + return response; + } + private MLTaskResponse createErrorResponse(Exception error) { String errorMessage = error.getMessage(); if (errorMessage == null || errorMessage.trim().isEmpty()) { diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/streaming/StreamingHandler.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/streaming/StreamingHandler.java new file mode 100644 index 0000000000..c7e843893f --- /dev/null +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/streaming/StreamingHandler.java @@ -0,0 +1,24 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.engine.algorithms.remote.streaming; + +import java.util.Map; + +import org.opensearch.ml.common.transport.MLTaskResponse; + +/** + * Streaming handler interface. + */ +public interface StreamingHandler { + void startStream( + String action, + Map parameters, + String payload, + StreamPredictActionListener listener + ); + + void handleError(Throwable error, StreamPredictActionListener listener); +} diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/streaming/StreamingHandlerFactory.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/streaming/StreamingHandlerFactory.java new file mode 100644 index 0000000000..fb4ecf9fbf --- /dev/null +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/streaming/StreamingHandlerFactory.java @@ -0,0 +1,58 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.engine.algorithms.remote.streaming; + +import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.LLM_INTERFACE_BEDROCK_CONVERSE_CLAUDE; +import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.LLM_INTERFACE_OPENAI_V1_CHAT_COMPLETIONS; + +import java.lang.reflect.Constructor; + +import org.opensearch.ml.common.connector.Connector; +import org.opensearch.ml.common.connector.ConnectorClientConfig; +import org.opensearch.ml.common.exception.MLException; + +import software.amazon.awssdk.http.async.SdkAsyncHttpClient; + +public class StreamingHandlerFactory { + + public static StreamingHandler createHandler( + String llmInterface, + Connector connector, + SdkAsyncHttpClient httpClient, + ConnectorClientConfig connectorClientConfig + ) { + switch (llmInterface.toLowerCase()) { + case LLM_INTERFACE_BEDROCK_CONVERSE_CLAUDE: + return createBedrockHandler(httpClient, connector); + case LLM_INTERFACE_OPENAI_V1_CHAT_COMPLETIONS: + return createHttpHandler(llmInterface, connector, connectorClientConfig); + default: + throw new IllegalArgumentException("Unsupported LLM interface: " + llmInterface); + } + } + + private static StreamingHandler createBedrockHandler(SdkAsyncHttpClient httpClient, Connector connector) { + try { + // Use reflection to avoid hard dependency + Class handlerClass = Class.forName("org.opensearch.ml.engine.algorithms.remote.streaming.BedrockStreamingHandler"); + Constructor constructor = handlerClass + .getConstructor(SdkAsyncHttpClient.class, Class.forName("org.opensearch.ml.common.connector.AwsConnector")); + return (StreamingHandler) constructor.newInstance(httpClient, connector); + } catch (ClassNotFoundException e) { + throw new MLException("Bedrock streaming not available - Bedrock SDK not found", e); + } catch (Exception e) { + throw new MLException("Failed to initialize Bedrock streaming handler", e); + } + } + + private static StreamingHandler createHttpHandler( + String llmInterface, + Connector connector, + ConnectorClientConfig connectorClientConfig + ) { + return new HttpStreamingHandler(llmInterface, connector, connectorClientConfig); + } +} diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/MLEngineTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/MLEngineTest.java index 83b6517016..1237957ab3 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/MLEngineTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/MLEngineTest.java @@ -297,7 +297,7 @@ public void executeLocalSampleCalculator() throws Exception { LocalSampleCalculatorOutput output = (LocalSampleCalculatorOutput) o; assertEquals(3.0, output.getResult(), 1e-5); }, e -> { fail("Test failed"); }); - mlEngine.execute(input, listener); + mlEngine.execute(input, listener, null); } @Test @@ -309,7 +309,7 @@ public void executeWithMetricsCorrelationThrowsException() throws Exception { inputData.add(new float[] { 1.0f, 2.0f, 3.0f, 4.0f }); inputData.add(new float[] { 1.0f, 2.0f, 3.0f, 4.0f }); Input input = MetricsCorrelationInput.builder().inputData(inputData).build(); - mlEngine.execute(input, null); + mlEngine.execute(input, null, null); } @Test @@ -336,7 +336,7 @@ public XContentBuilder toXContent(XContentBuilder xContentBuilder, Params params LocalSampleCalculatorOutput output = (LocalSampleCalculatorOutput) o; assertEquals(3.0, output.getResult(), 1e-5); }, e -> { fail("Test failed"); }); - mlEngine.execute(input, listener); + mlEngine.execute(input, listener, null); } private MLModel trainKMeansModel() { diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLAgentExecutorTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLAgentExecutorTest.java index c3c9005c7b..538a643514 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLAgentExecutorTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLAgentExecutorTest.java @@ -271,7 +271,7 @@ public void test_HappyCase_ReturnsResult() throws IOException { ActionListener listener = invocation.getArgument(2); listener.onResponse(modelTensor); return null; - }).when(mlAgentRunner).run(Mockito.any(), Mockito.any(), Mockito.any()); + }).when(mlAgentRunner).run(Mockito.any(), Mockito.any(), Mockito.any(), Mockito.any()); GetResponse agentGetResponse = prepareMLAgent("test-agent-id", false, null); Mockito.doAnswer(invocation -> { @@ -307,7 +307,7 @@ public void test_AgentRunnerReturnsListOfModelTensor_ReturnsResult() throws IOEx ActionListener> listener = invocation.getArgument(2); listener.onResponse(response); return null; - }).when(mlAgentRunner).run(Mockito.any(), Mockito.any(), Mockito.any()); + }).when(mlAgentRunner).run(Mockito.any(), Mockito.any(), Mockito.any(), Mockito.any()); GetResponse agentGetResponse = prepareMLAgent("test-agent-id", false, null); Mockito.doAnswer(invocation -> { @@ -345,7 +345,7 @@ public void test_AgentRunnerReturnsListOfModelTensors_ReturnsResult() throws IOE ActionListener> listener = invocation.getArgument(2); listener.onResponse(response); return null; - }).when(mlAgentRunner).run(Mockito.any(), Mockito.any(), Mockito.any()); + }).when(mlAgentRunner).run(Mockito.any(), Mockito.any(), Mockito.any(), Mockito.any()); GetResponse agentGetResponse = prepareMLAgent("test-agent-id", false, null); Mockito.doAnswer(invocation -> { @@ -379,7 +379,7 @@ public void test_AgentRunnerReturnsListOfString_ReturnsResult() throws IOExcepti ActionListener> listener = invocation.getArgument(2); listener.onResponse(response); return null; - }).when(mlAgentRunner).run(Mockito.any(), Mockito.any(), Mockito.any()); + }).when(mlAgentRunner).run(Mockito.any(), Mockito.any(), Mockito.any(), Mockito.any()); GetResponse agentGetResponse = prepareMLAgent("test-agent-id", false, null); Mockito.doAnswer(invocation -> { @@ -413,7 +413,7 @@ public void test_AgentRunnerReturnsString_ReturnsResult() throws IOException { ActionListener listener = invocation.getArgument(2); listener.onResponse("response"); return null; - }).when(mlAgentRunner).run(Mockito.any(), Mockito.any(), Mockito.any()); + }).when(mlAgentRunner).run(Mockito.any(), Mockito.any(), Mockito.any(), Mockito.any()); Mockito.doReturn(mlAgentRunner).when(mlAgentExecutor).getAgentRunner(Mockito.any()); GetResponse agentGetResponse = prepareMLAgent("test-agent-id", false, null); Mockito.doAnswer(invocation -> { @@ -451,7 +451,7 @@ public void test_AgentRunnerReturnsModelTensorOutput_ReturnsResult() throws IOEx ActionListener listener = invocation.getArgument(2); listener.onResponse(modelTensorOutput); return null; - }).when(mlAgentRunner).run(Mockito.any(), Mockito.any(), Mockito.any()); + }).when(mlAgentRunner).run(Mockito.any(), Mockito.any(), Mockito.any(), Mockito.any()); GetResponse agentGetResponse = prepareMLAgent("test-agent-id", false, null); Mockito.doAnswer(invocation -> { // Extract the ActionListener argument from the method invocation @@ -485,7 +485,7 @@ public void test_CreateConversation_ReturnsResult() throws IOException { ActionListener listener = invocation.getArgument(2); listener.onResponse(modelTensor); return null; - }).when(mlAgentRunner).run(Mockito.any(), Mockito.any(), Mockito.any()); + }).when(mlAgentRunner).run(Mockito.any(), Mockito.any(), Mockito.any(), Mockito.any()); GetResponse agentGetResponse = prepareMLAgent("test-agent-id", false, null); Mockito.doAnswer(invocation -> { @@ -1088,7 +1088,7 @@ public void test_query_planning_agentic_search_enabled() throws IOException { ActionListener listener = invocation.getArgument(2); listener.onResponse(modelTensor); return null; - }).when(mlAgentRunner).run(Mockito.any(), Mockito.any(), Mockito.any()); + }).when(mlAgentRunner).run(Mockito.any(), Mockito.any(), Mockito.any(), Mockito.any()); // Execute the agent mlAgentExecutorWithEnabledSearch.execute(getAgentMLInput(), agentActionListener); @@ -1179,7 +1179,7 @@ public void test_ExistingConversation_WithMemoryAndParentInteractionId() throws ActionListener listener = invocation.getArgument(2); listener.onResponse(modelTensor); return null; - }).when(mlAgentRunner).run(Mockito.any(), Mockito.any(), Mockito.any()); + }).when(mlAgentRunner).run(Mockito.any(), Mockito.any(), Mockito.any(), Mockito.any()); GetResponse agentGetResponse = prepareMLAgent("test-agent-id", false, null); Mockito.doAnswer(invocation -> { @@ -1218,7 +1218,7 @@ public void test_AgentFailure_UpdatesInteractionWithFailure() throws IOException ActionListener listener = invocation.getArgument(2); listener.onFailure(testException); return null; - }).when(mlAgentRunner).run(Mockito.any(), Mockito.any(), Mockito.any()); + }).when(mlAgentRunner).run(Mockito.any(), Mockito.any(), Mockito.any(), Mockito.any()); GetResponse agentGetResponse = prepareMLAgent("test-agent-id", false, null); Mockito.doAnswer(invocation -> { @@ -1292,7 +1292,7 @@ public void test_ExecuteAgent_SyncMode() throws IOException { ActionListener listener = invocation.getArgument(2); listener.onResponse(modelTensor); return null; - }).when(mlAgentRunner).run(Mockito.any(), Mockito.any(), Mockito.any()); + }).when(mlAgentRunner).run(Mockito.any(), Mockito.any(), Mockito.any(), Mockito.any()); GetResponse agentGetResponse = prepareMLAgent("test-agent-id", false, null); Mockito.doAnswer(invocation -> { @@ -1380,7 +1380,7 @@ public void test_UpdateInteractionWithFailure() throws IOException { ActionListener listener = invocation.getArgument(2); listener.onFailure(testException); return null; - }).when(mlAgentRunner).run(Mockito.any(), Mockito.any(), Mockito.any()); + }).when(mlAgentRunner).run(Mockito.any(), Mockito.any(), Mockito.any(), Mockito.any()); GetResponse agentGetResponse = prepareMLAgent("test-agent-id", false, null); Mockito.doAnswer(invocation -> { @@ -1523,7 +1523,7 @@ public void test_UpdateInteractionFailure_LogLines() throws IOException { ActionListener listener = invocation.getArgument(2); listener.onFailure(new RuntimeException("Test failure")); return null; - }).when(mlAgentRunner).run(Mockito.any(), Mockito.any(), Mockito.any()); + }).when(mlAgentRunner).run(Mockito.any(), Mockito.any(), Mockito.any(), Mockito.any()); GetResponse agentGetResponse = prepareMLAgent("test-agent-id", false, null); Mockito.doAnswer(invocation -> { @@ -1563,7 +1563,7 @@ public void test_UpdateInteractionFailure_ErrorCallback() throws IOException { ActionListener listener = invocation.getArgument(2); listener.onFailure(new RuntimeException("Test failure")); return null; - }).when(mlAgentRunner).run(Mockito.any(), Mockito.any(), Mockito.any()); + }).when(mlAgentRunner).run(Mockito.any(), Mockito.any(), Mockito.any(), Mockito.any()); GetResponse agentGetResponse = prepareMLAgent("test-agent-id", false, null); Mockito.doAnswer(invocation -> { diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java index 237754dde0..f6c3e3618e 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/MLChatAgentRunnerTest.java @@ -200,7 +200,7 @@ public void testParsingJsonBlockFromResponse() { Map params = new HashMap<>(); params.put(MLAgentExecutor.PARENT_INTERACTION_ID, "parent_interaction_id"); params.put("verbose", "true"); - mlChatAgentRunner.run(mlAgent, params, agentActionListener); + mlChatAgentRunner.run(mlAgent, params, agentActionListener, null); // Capture the response passed to the listener ArgumentCaptor responseCaptor = ArgumentCaptor.forClass(Object.class); @@ -240,7 +240,7 @@ public void testParsingJsonBlockFromResponse2() { Map params = new HashMap<>(); params.put(MLAgentExecutor.PARENT_INTERACTION_ID, "parent_interaction_id"); params.put("verbose", "true"); - mlChatAgentRunner.run(mlAgent, params, agentActionListener); + mlChatAgentRunner.run(mlAgent, params, agentActionListener, null); // Capture the response passed to the listener ArgumentCaptor responseCaptor = ArgumentCaptor.forClass(Object.class); @@ -280,7 +280,7 @@ public void testParsingJsonBlockFromResponse3() { Map params = new HashMap<>(); params.put(MLAgentExecutor.PARENT_INTERACTION_ID, "parent_interaction_id"); params.put("verbose", "true"); - mlChatAgentRunner.run(mlAgent, params, agentActionListener); + mlChatAgentRunner.run(mlAgent, params, agentActionListener, null); // Capture the response passed to the listener ArgumentCaptor responseCaptor = ArgumentCaptor.forClass(Object.class); @@ -320,7 +320,7 @@ public void testParsingJsonBlockFromResponse4() { Map params = new HashMap<>(); params.put(MLAgentExecutor.PARENT_INTERACTION_ID, "parent_interaction_id"); params.put("verbose", "false"); - mlChatAgentRunner.run(mlAgent, params, agentActionListener); + mlChatAgentRunner.run(mlAgent, params, agentActionListener, null); // Capture the response passed to the listener ArgumentCaptor responseCaptor = ArgumentCaptor.forClass(Object.class); @@ -364,7 +364,7 @@ public void testRunWithIncludeOutputNotSet() { .memory(mlMemorySpec) .tools(Arrays.asList(firstToolSpec, secondToolSpec)) .build(); - mlChatAgentRunner.run(mlAgent, new HashMap<>(), agentActionListener); + mlChatAgentRunner.run(mlAgent, new HashMap<>(), agentActionListener, null); Mockito.verify(agentActionListener).onResponse(objectCaptor.capture()); ModelTensorOutput modelTensorOutput = (ModelTensorOutput) objectCaptor.getValue(); List agentOutput = modelTensorOutput.getMlModelOutputs().get(1).getMlModelTensors(); @@ -394,7 +394,7 @@ public void testRunWithIncludeOutputMLModel() { .memory(mlMemorySpec) .tools(Arrays.asList(firstToolSpec, secondToolSpec)) .build(); - mlChatAgentRunner.run(mlAgent, new HashMap<>(), agentActionListener); + mlChatAgentRunner.run(mlAgent, new HashMap<>(), agentActionListener, null); Mockito.verify(agentActionListener).onResponse(objectCaptor.capture()); ModelTensorOutput modelTensorOutput = (ModelTensorOutput) objectCaptor.getValue(); List agentOutput = modelTensorOutput.getMlModelOutputs().get(1).getMlModelTensors(); @@ -429,7 +429,7 @@ public void testRunWithIncludeOutputSet() { .tools(Arrays.asList(firstToolSpec, secondToolSpec)) .build(); HashMap params = new HashMap<>(); - mlChatAgentRunner.run(mlAgent, params, agentActionListener); + mlChatAgentRunner.run(mlAgent, params, agentActionListener, null); Mockito.verify(agentActionListener).onResponse(objectCaptor.capture()); ModelTensorOutput modelTensorOutput = (ModelTensorOutput) objectCaptor.getValue(); List agentOutput = modelTensorOutput.getMlModelOutputs().get(1).getMlModelTensors(); @@ -481,7 +481,7 @@ public void testChatHistoryExcludeOngoingQuestion() { HashMap params = new HashMap<>(); params.put(MESSAGE_HISTORY_LIMIT, "5"); - mlChatAgentRunner.run(mlAgent, params, agentActionListener); + mlChatAgentRunner.run(mlAgent, params, agentActionListener, null); Mockito.verify(agentActionListener).onResponse(objectCaptor.capture()); String chatHistory = params.get(MLChatAgentRunner.CHAT_HISTORY); Assert.assertFalse(chatHistory.contains("input-99")); @@ -537,7 +537,7 @@ private void testInteractions(String maxInteraction) { HashMap params = new HashMap<>(); params.put("verbose", "true"); - mlChatAgentRunner.run(mlAgent, params, agentActionListener); + mlChatAgentRunner.run(mlAgent, params, agentActionListener, null); Mockito.verify(agentActionListener).onResponse(objectCaptor.capture()); String chatHistory = params.get(MLChatAgentRunner.CHAT_HISTORY); Assert.assertFalse(chatHistory.contains("input-99")); @@ -566,7 +566,7 @@ public void testChatHistoryException() { }).when(conversationIndexMemory).getMessages(memoryInteractionCapture.capture(), messageHistoryLimitCapture.capture()); HashMap params = new HashMap<>(); - mlChatAgentRunner.run(mlAgent, params, agentActionListener); + mlChatAgentRunner.run(mlAgent, params, agentActionListener, null); // Verifying that onFailure was called verify(agentActionListener).onFailure(any(RuntimeException.class)); @@ -584,7 +584,7 @@ public void testToolValidationSuccess() { Map params = createAgentParamsWithAction(FIRST_TOOL, "someInput"); // Run the MLChatAgentRunner - mlChatAgentRunner.run(mlAgent, params, agentActionListener); + mlChatAgentRunner.run(mlAgent, params, agentActionListener, null); // Verify that the tool's run method was called verify(firstTool).run(any(), any()); @@ -606,7 +606,7 @@ public void testToolValidationFailure() { .when(firstTool) .run(Mockito.anyMap(), nextStepListenerCaptor.capture()); // Run the MLChatAgentRunner - mlChatAgentRunner.run(mlAgent, params, agentActionListener); + mlChatAgentRunner.run(mlAgent, params, agentActionListener, null); // Verify that the tool's run method was not called verify(firstTool, never()).run(any(), any()); @@ -632,7 +632,7 @@ public void testToolNotFound() { Map params = createAgentParamsWithAction("nonExistentTool", "someInput"); // Run the MLChatAgentRunner - mlChatAgentRunner.run(mlAgent, params, agentActionListener); + mlChatAgentRunner.run(mlAgent, params, agentActionListener, null); // Verify that no tool's run method was called verify(firstTool, never()).run(any(), any()); @@ -655,7 +655,7 @@ public void testToolFailure() { .when(firstTool) .run(Mockito.anyMap(), toolListenerCaptor.capture()); // Run the MLChatAgentRunner - mlChatAgentRunner.run(mlAgent, params, agentActionListener); + mlChatAgentRunner.run(mlAgent, params, agentActionListener, null); // Verify that the tool's run method was called verify(firstTool).run(any(), any()); @@ -681,7 +681,7 @@ public void testToolThrowException() { .when(firstTool) .run(Mockito.anyMap(), toolListenerCaptor.capture()); // Run the MLChatAgentRunner - mlChatAgentRunner.run(mlAgent, params, agentActionListener); + mlChatAgentRunner.run(mlAgent, params, agentActionListener, null); // Verify that the tool's run method was called verify(firstTool).run(any(), any()); @@ -703,7 +703,7 @@ public void testToolParameters() { Map params = createAgentParamsWithAction(FIRST_TOOL, "someInput"); // Run the MLChatAgentRunner. - mlChatAgentRunner.run(mlAgent, params, agentActionListener); + mlChatAgentRunner.run(mlAgent, params, agentActionListener, null); // Verify that the tool's run method was called. verify(firstTool).run(any(), any()); @@ -731,7 +731,7 @@ public void testToolUseOriginalInput() { doReturn(true).when(firstTool).useOriginalInput(); // Run the MLChatAgentRunner. - mlChatAgentRunner.run(mlAgent, params, agentActionListener); + mlChatAgentRunner.run(mlAgent, params, agentActionListener, null); // Verify that the tool's run method was called. verify(firstTool).run(any(), any()); @@ -769,7 +769,7 @@ public void testScratchpad_E2E_Flow() { MLAgent mlAgent = createMLAgentWithScratchpadTools(); Map params = new HashMap<>(); params.put(MLAgentExecutor.PARENT_INTERACTION_ID, "parent_interaction_id_for_scratchpad_test"); - mlChatAgentRunner.run(mlAgent, params, agentActionListener); + mlChatAgentRunner.run(mlAgent, params, agentActionListener, null); // Also verify the final response to the user verify(agentActionListener).onResponse(objectCaptor.capture()); @@ -797,7 +797,7 @@ public void testToolConfig() { doReturn(false).when(firstTool).useOriginalInput(); // Run the MLChatAgentRunner. - mlChatAgentRunner.run(mlAgent, params, agentActionListener); + mlChatAgentRunner.run(mlAgent, params, agentActionListener, null); // Verify that the tool's run method was called. verify(firstTool).run(any(), any()); @@ -827,7 +827,7 @@ public void testToolConfigWithInputPlaceholder() { doReturn(false).when(firstTool).useOriginalInput(); // Run the MLChatAgentRunner. - mlChatAgentRunner.run(mlAgent, params, agentActionListener); + mlChatAgentRunner.run(mlAgent, params, agentActionListener, null); // Verify that the tool's run method was called. verify(firstTool).run(any(), any()); @@ -860,7 +860,7 @@ public void testSaveLastTraceFailure() { return null; }).when(conversationIndexMemory).save(any(), any(), any(), any(), conversationIndexMemoryCapture.capture()); // Run the MLChatAgentRunner - mlChatAgentRunner.run(mlAgent, params, agentActionListener); + mlChatAgentRunner.run(mlAgent, params, agentActionListener, null); // Verify that the tool's run method was called verify(firstTool).run(any(), any()); @@ -912,7 +912,7 @@ public void testToolExecutionWithChatHistoryParameter() { HashMap params = new HashMap<>(); params.put(MESSAGE_HISTORY_LIMIT, "5"); - mlChatAgentRunner.run(mlAgent, params, agentActionListener); + mlChatAgentRunner.run(mlAgent, params, agentActionListener, null); Mockito.verify(agentActionListener).onResponse(objectCaptor.capture()); String chatHistory = params.get(MLChatAgentRunner.CHAT_HISTORY); Assert.assertFalse(chatHistory.contains("input-99")); @@ -1040,7 +1040,7 @@ public void testMaxIterationsReached() { Map params = new HashMap<>(); params.put(MLAgentExecutor.PARENT_INTERACTION_ID, "parent_interaction_id"); - mlChatAgentRunner.run(mlAgent, params, agentActionListener); + mlChatAgentRunner.run(mlAgent, params, agentActionListener, null); // Verify response is captured verify(agentActionListener).onResponse(objectCaptor.capture()); @@ -1079,7 +1079,7 @@ public void testMaxIterationsReachedWithValidThought() { Map params = new HashMap<>(); params.put(MLAgentExecutor.PARENT_INTERACTION_ID, "parent_interaction_id"); - mlChatAgentRunner.run(mlAgent, params, agentActionListener); + mlChatAgentRunner.run(mlAgent, params, agentActionListener, null); // Verify response is captured verify(agentActionListener).onResponse(objectCaptor.capture()); diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/StreamingWrapperTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/StreamingWrapperTest.java new file mode 100644 index 0000000000..79fba06c5b --- /dev/null +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/agent/StreamingWrapperTest.java @@ -0,0 +1,264 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.engine.algorithms.agent; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertTrue; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.doThrow; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import org.junit.Before; +import org.junit.Test; +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.opensearch.action.ActionRequest; +import org.opensearch.core.action.ActionListener; +import org.opensearch.ml.common.agent.LLMSpec; +import org.opensearch.ml.common.output.model.ModelTensor; +import org.opensearch.ml.common.output.model.ModelTensorOutput; +import org.opensearch.ml.common.transport.MLTaskResponse; +import org.opensearch.ml.common.transport.prediction.MLPredictionStreamTaskAction; +import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction; +import org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest; +import org.opensearch.transport.TransportChannel; +import org.opensearch.transport.client.Client; + +public class StreamingWrapperTest { + + @Mock + private TransportChannel channel; + + @Mock + private Client client; + + @Mock + private ActionListener listener; + + @Mock + private ActionListener mlTaskListener; + + private StreamingWrapper streamingWrapper; + private StreamingWrapper nonStreamingWrapper; + + @Before + public void setUp() { + MockitoAnnotations.openMocks(this); + streamingWrapper = new StreamingWrapper(channel, client); + nonStreamingWrapper = new StreamingWrapper(null, client); + } + + @Test + public void testConstructor() { + StreamingWrapper wrapper = new StreamingWrapper(channel, client); + assertNotNull(wrapper); + } + + @Test + public void testFixInteractionRoleStreaming() { + List interactions = new ArrayList<>(); + interactions.add("{\"tool_calls\":[{\"name\":\"test\"}]}"); + + streamingWrapper.fixInteractionRole(interactions); + + assertTrue(interactions.get(0).contains("\"role\":\"assistant\"")); + } + + @Test + public void testFixInteractionRoleNonStreaming() { + List interactions = new ArrayList<>(); + interactions.add("{\"tool_calls\":[{\"name\":\"test\"}]}"); + + nonStreamingWrapper.fixInteractionRole(interactions); + + assertFalse(interactions.get(0).contains("\"role\":\"assistant\"")); + } + + @Test + public void testFixInteractionRoleWithEmptyInteractions() { + List interactions = new ArrayList<>(); + + streamingWrapper.fixInteractionRole(interactions); + + assertTrue(interactions.isEmpty()); + } + + @Test + public void testFixInteractionRoleWithExistingRole() { + List interactions = new ArrayList<>(); + interactions.add("{\"role\":\"user\",\"tool_calls\":[{\"name\":\"test\"}]}"); + + streamingWrapper.fixInteractionRole(interactions); + + assertTrue(interactions.get(0).contains("\"role\":\"user\"")); + assertFalse(interactions.get(0).contains("\"role\":\"assistant\"")); + } + + @Test + public void testCreatePredictionRequestStreaming() { + LLMSpec llm = mock(LLMSpec.class); + when(llm.getModelId()).thenReturn("test-model"); + Map parameters = new HashMap<>(); + parameters.put("key", "value"); + + ActionRequest request = streamingWrapper.createPredictionRequest(llm, parameters, "tenant1"); + + assertNotNull(request); + assertTrue(request instanceof MLPredictionTaskRequest); + MLPredictionTaskRequest mlRequest = (MLPredictionTaskRequest) request; + assertEquals("test-model", mlRequest.getModelId()); + assertFalse(mlRequest.isDispatchTask()); // Should be false for streaming + } + + @Test + public void testCreatePredictionRequestNonStreaming() { + LLMSpec llm = mock(LLMSpec.class); + when(llm.getModelId()).thenReturn("test-model"); + Map parameters = new HashMap<>(); + + ActionRequest request = nonStreamingWrapper.createPredictionRequest(llm, parameters, "tenant1"); + + MLPredictionTaskRequest mlRequest = (MLPredictionTaskRequest) request; + assertTrue(mlRequest.isDispatchTask()); // Should be true for non-streaming + } + + @Test + public void testExecuteRequestStreaming() { + MLPredictionTaskRequest request = mock(MLPredictionTaskRequest.class); + + streamingWrapper.executeRequest(request, mlTaskListener); + + verify(request).setStreamingChannel(channel); + verify(client).execute(eq(MLPredictionStreamTaskAction.INSTANCE), eq(request), eq(mlTaskListener)); + } + + @Test + public void testExecuteRequestNonStreaming() { + MLPredictionTaskRequest request = mock(MLPredictionTaskRequest.class); + + nonStreamingWrapper.executeRequest(request, mlTaskListener); + + verify(request, never()).setStreamingChannel(any()); + verify(client).execute(eq(MLPredictionTaskAction.INSTANCE), eq(request), eq(mlTaskListener)); + } + + @Test + public void testSendCompletionChunkStreaming() throws Exception { + streamingWrapper.sendCompletionChunk("session1", "parent1"); + + ArgumentCaptor responseCaptor = ArgumentCaptor.forClass(MLTaskResponse.class); + verify(channel).sendResponseBatch(responseCaptor.capture()); + + MLTaskResponse response = responseCaptor.getValue(); + assertNotNull(response); + } + + @Test + public void testSendCompletionChunkNonStreaming() throws Exception { + nonStreamingWrapper.sendCompletionChunk("session1", "parent1"); + + verify(channel, never()).sendResponseBatch(any()); + } + + @Test + public void testSendCompletionChunkWithException() throws Exception { + doThrow(new RuntimeException("Channel error")).when(channel).sendResponseBatch(any()); + + // Should not throw exception, just log warning + streamingWrapper.sendCompletionChunk("session1", "parent1"); + + verify(channel).sendResponseBatch(any()); + } + + @Test + public void testSendFinalResponseStreaming() { + streamingWrapper.sendFinalResponse("session1", listener, "parent1", true, null, null, "answer"); + + verify(listener).onResponse("Streaming completed"); + } + + @Test + public void testSendToolResponseStreaming() throws Exception { + streamingWrapper.sendToolResponse("tool output", "session1", "parent1"); + + ArgumentCaptor responseCaptor = ArgumentCaptor.forClass(MLTaskResponse.class); + verify(channel).sendResponseBatch(responseCaptor.capture()); + + MLTaskResponse response = responseCaptor.getValue(); + ModelTensorOutput output = (ModelTensorOutput) response.getOutput(); + List tensors = output.getMlModelOutputs().get(0).getMlModelTensors(); + + // Verify response tensor contains the tool output + boolean foundContent = false; + for (ModelTensor tensor : tensors) { + if ("response".equals(tensor.getName()) && tensor.getDataAsMap() != null) { + Map dataMap = tensor.getDataAsMap(); + if ("tool output".equals(dataMap.get("content"))) { + foundContent = true; + assertFalse((Boolean) dataMap.get("is_last")); + } + } + } + assertTrue(foundContent); + } + + @Test + public void testSendToolResponseNonStreaming() throws Exception { + nonStreamingWrapper.sendToolResponse("tool output", "session1", "parent1"); + + verify(channel, never()).sendResponseBatch(any()); + } + + @Test + public void testSendToolResponseWithException() throws Exception { + doThrow(new RuntimeException("Channel error")).when(channel).sendResponseBatch(any()); + + // Should not throw exception, just log error + streamingWrapper.sendToolResponse("tool output", "session1", "parent1"); + + verify(channel).sendResponseBatch(any()); + } + + @Test + public void testCreateStreamChunkStructure() throws Exception { + streamingWrapper.sendCompletionChunk("test-session", "test-parent"); + + ArgumentCaptor responseCaptor = ArgumentCaptor.forClass(MLTaskResponse.class); + verify(channel).sendResponseBatch(responseCaptor.capture()); + + MLTaskResponse response = responseCaptor.getValue(); + ModelTensorOutput output = (ModelTensorOutput) response.getOutput(); + List tensors = output.getMlModelOutputs().get(0).getMlModelTensors(); + + assertEquals(3, tensors.size()); + + // Find specific tensors by name + ModelTensor memoryTensor = tensors.stream().filter(t -> "memory_id".equals(t.getName())).findFirst().orElse(null); + ModelTensor parentTensor = tensors.stream().filter(t -> "parent_interaction_id".equals(t.getName())).findFirst().orElse(null); + ModelTensor responseTensor = tensors.stream().filter(t -> "response".equals(t.getName())).findFirst().orElse(null); + + assertNotNull(memoryTensor); + assertNotNull(parentTensor); + assertNotNull(responseTensor); + + assertEquals("test-session", memoryTensor.getResult()); + assertEquals("test-parent", parentTensor.getResult()); + assertTrue((Boolean) responseTensor.getDataAsMap().get("is_last")); + + } +} diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/AbstractConnectorExecutorTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/AbstractConnectorExecutorTest.java index 1b18565a21..56ad15cbd4 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/AbstractConnectorExecutorTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/AbstractConnectorExecutorTest.java @@ -3,27 +3,17 @@ import static org.junit.Assert.*; import static org.mockito.Mockito.*; -import java.util.concurrent.atomic.AtomicBoolean; - import org.junit.Before; import org.junit.Test; -import org.mockito.ArgumentCaptor; import org.mockito.Mock; import org.mockito.MockitoAnnotations; import org.opensearch.ml.common.connector.AwsConnector; import org.opensearch.ml.common.connector.ConnectorClientConfig; -import org.opensearch.ml.common.output.model.ModelTensor; -import org.opensearch.ml.common.output.model.ModelTensorOutput; -import org.opensearch.ml.common.output.model.ModelTensors; -import org.opensearch.ml.common.transport.MLTaskResponse; public class AbstractConnectorExecutorTest { @Mock private AwsConnector mockConnector; - @Mock - private StreamPredictActionListener mockActionListener; - private ConnectorClientConfig connectorClientConfig; private AbstractConnectorExecutor executor; @@ -52,49 +42,4 @@ public void testValidateWithNonNullConfigButNullValues() { assertEquals(ConnectorClientConfig.CONNECTION_TIMEOUT_DEFAULT_VALUE, executor.getConnectorClientConfig().getConnectionTimeout()); assertEquals(ConnectorClientConfig.READ_TIMEOUT_DEFAULT_VALUE, executor.getConnectorClientConfig().getReadTimeout()); } - - @Test - public void testSendContentResponse() { - String content = "test content"; - boolean isLast = false; - - executor.sendContentResponse(content, isLast, mockActionListener); - - ArgumentCaptor responseCaptor = ArgumentCaptor.forClass(MLTaskResponse.class); - ArgumentCaptor isLastCaptor = ArgumentCaptor.forClass(Boolean.class); - - verify(mockActionListener).onStreamResponse(responseCaptor.capture(), isLastCaptor.capture()); - - MLTaskResponse response = responseCaptor.getValue(); - ModelTensorOutput output = (ModelTensorOutput) response.getOutput(); - ModelTensors tensors = output.getMlModelOutputs().get(0); - ModelTensor tensor = tensors.getMlModelTensors().get(0); - - assertEquals("response", tensor.getName()); - assertEquals(content, tensor.getDataAsMap().get("content")); - assertEquals(isLast, tensor.getDataAsMap().get("is_last")); - assertEquals(isLast, isLastCaptor.getValue()); - } - - @Test - public void testSendContentResponseWithLastFlag() { - String content = "final content"; - boolean isLast = true; - - executor.sendContentResponse(content, isLast, mockActionListener); - - ArgumentCaptor responseCaptor = ArgumentCaptor.forClass(MLTaskResponse.class); - ArgumentCaptor isLastCaptor = ArgumentCaptor.forClass(Boolean.class); - - verify(mockActionListener).onStreamResponse(responseCaptor.capture(), isLastCaptor.capture()); - - assertTrue(isLastCaptor.getValue()); - } - - @Test - public void testSendCompletionResponseAlreadyClosed() { - AtomicBoolean isStreamClosed = new AtomicBoolean(true); - executor.sendCompletionResponse(isStreamClosed, mockActionListener); - verify(mockActionListener, never()).onStreamResponse(any(), anyBoolean()); - } } diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutorTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutorTest.java index 08bf9d4edf..b1234dae93 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutorTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutorTest.java @@ -12,7 +12,6 @@ import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.doNothing; import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.never; import static org.mockito.Mockito.spy; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; @@ -64,6 +63,7 @@ import org.opensearch.ml.common.output.model.ModelTensorOutput; import org.opensearch.ml.common.output.model.ModelTensors; import org.opensearch.ml.common.transport.MLTaskResponse; +import org.opensearch.ml.engine.algorithms.remote.streaming.StreamPredictActionListener; import org.opensearch.ml.engine.encryptor.Encryptor; import org.opensearch.ml.engine.encryptor.EncryptorImpl; import org.opensearch.script.ScriptService; @@ -267,7 +267,7 @@ public void executePredict_TextDocsInferenceInput_withStepSize() { ); verify(actionListener, times(0)).onFailure(any()); - verify(executor, times(3)).preparePayloadAndInvoke(anyString(), any(), any(), any(), any()); + verify(executor, times(3)).preparePayloadAndInvoke(anyString(), any(), any(), any()); } @Test @@ -946,24 +946,6 @@ public Void answer(InvocationOnMock invocation) { assertEquals("test failure retryable", exceptionArgumentCaptor.getValue().getSuppressed()[1].getMessage()); } - @Test - public void testInvokeRemoteServiceStream_ValidInterface() { - AwsConnector mockConnector = mock(AwsConnector.class); - when(mockConnector.getAccessKey()).thenReturn("test-access-key"); - when(mockConnector.getSecretKey()).thenReturn("test-secret-key"); - when(mockConnector.getRegion()).thenReturn("us-east-1"); - - AwsConnectorExecutor executor = new AwsConnectorExecutor(mockConnector); - MLInput mlInput = mock(MLInput.class); - Map parameters = Map.of("_llm_interface", "bedrock/converse/claude", "model", "claude-v2", "inputs", "test input"); - String payload = "test payload"; - ExecutionContext executionContext = new ExecutionContext(123); - StreamPredictActionListener actionListener = mock(StreamPredictActionListener.class); - - executor.invokeRemoteServiceStream("predict", mlInput, parameters, payload, executionContext, actionListener); - verify(actionListener, never()).onFailure(any()); - } - @Test public void testInvokeRemoteServiceStream_WithException() { AwsConnector mockConnector = mock(AwsConnector.class); diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtilsTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtilsTest.java index e605c81323..a5821944d0 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtilsTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/ConnectorUtilsTest.java @@ -516,7 +516,7 @@ public void buildSdkRequest_InvalidEndpoint_ThrowException() { } @Test - public void testBuildOKHttpRequestPOST_WithPayload() { + public void testBuildOKHttpStreamingRequest_WithPayload() { ConnectorAction predictAction = ConnectorAction .builder() .actionType(PREDICT) @@ -540,7 +540,7 @@ public void testBuildOKHttpRequestPOST_WithPayload() { Map parameters = ImmutableMap.of("input", "test input"); String payload = "{\"input\": \"test input\"}"; - Request request = ConnectorUtils.buildOKHttpRequestPOST(PREDICT.name(), connector, parameters, payload); + Request request = ConnectorUtils.buildOKHttpStreamingRequest(PREDICT.name(), connector, parameters, payload); assertEquals("POST", request.method()); assertEquals("http://test.com/mock", request.url().toString()); @@ -552,7 +552,7 @@ public void testBuildOKHttpRequestPOST_WithPayload() { } @Test - public void testBuildOKHttpRequestPOST_NullPayload() { + public void testBuildOKHttpStreamingRequest_NullPayload() { exceptionRule.expect(IllegalArgumentException.class); exceptionRule.expectMessage("Content length is 0. Aborting request to remote model"); @@ -573,11 +573,11 @@ public void testBuildOKHttpRequestPOST_NullPayload() { .build(); Map parameters = new HashMap<>(); - ConnectorUtils.buildOKHttpRequestPOST(PREDICT.name(), connector, parameters, null); + ConnectorUtils.buildOKHttpStreamingRequest(PREDICT.name(), connector, parameters, null); } @Test - public void testBuildOKHttpRequestPOST_NoHeaders() { + public void testBuildOKHttpStreamingRequest_NoHeaders() { ConnectorAction predictAction = ConnectorAction .builder() .actionType(PREDICT) @@ -597,7 +597,7 @@ public void testBuildOKHttpRequestPOST_NoHeaders() { Map parameters = new HashMap<>(); String payload = "{\"input\": \"test input\"}"; - Request request = ConnectorUtils.buildOKHttpRequestPOST(PREDICT.name(), connector, parameters, payload); + Request request = ConnectorUtils.buildOKHttpStreamingRequest(PREDICT.name(), connector, parameters, payload); assertEquals("POST", request.method()); assertEquals("http://test.com/mock", request.url().toString()); @@ -608,7 +608,7 @@ public void testBuildOKHttpRequestPOST_NoHeaders() { } @Test - public void testBuildOKHttpRequestPOST_WithParameters() { + public void testBuildOKHttpStreamingRequest_WithParameters() { ConnectorAction predictAction = ConnectorAction .builder() .actionType(PREDICT) @@ -628,7 +628,7 @@ public void testBuildOKHttpRequestPOST_WithParameters() { Map parameters = ImmutableMap.of("model", "gpt-3.5", "input", "test input"); String payload = "{\"input\": \"test input\"}"; - Request request = ConnectorUtils.buildOKHttpRequestPOST(PREDICT.name(), connector, parameters, payload); + Request request = ConnectorUtils.buildOKHttpStreamingRequest(PREDICT.name(), connector, parameters, payload); assertEquals("POST", request.method()); assertEquals("http://test.com/mock/gpt-3.5", request.url().toString()); diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutorTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutorTest.java index 0e7c540541..feb8dc173c 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutorTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutorTest.java @@ -40,6 +40,7 @@ import org.opensearch.ml.common.input.MLInput; import org.opensearch.ml.common.output.model.ModelTensors; import org.opensearch.ml.common.transport.MLTaskResponse; +import org.opensearch.ml.engine.algorithms.remote.streaming.StreamPredictActionListener; import com.google.common.collect.ImmutableMap; diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/RemoteConnectorExecutorTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/RemoteConnectorExecutorTest.java index be70680e91..b6a8d52567 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/RemoteConnectorExecutorTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/RemoteConnectorExecutorTest.java @@ -142,7 +142,7 @@ public void executePreparePayloadAndInvoke_SkipValidateMissingParameterDisabled( Exception exception = Assert .assertThrows( IllegalArgumentException.class, - () -> executor.preparePayloadAndInvoke(actionType, mlInput, null, actionListener, null) + () -> executor.preparePayloadAndInvoke(actionType, mlInput, null, actionListener) ); assert exception.getMessage().contains("Some parameter placeholder not filled in payload: role"); } @@ -162,7 +162,7 @@ public void executePreparePayloadAndInvoke_SkipValidateMissingParameterEnabled() String actionType = inputDataSet.getActionType().toString(); MLInput mlInput = MLInput.builder().algorithm(FunctionName.TEXT_EMBEDDING).inputDataset(inputDataSet).build(); - executor.preparePayloadAndInvoke(actionType, mlInput, null, actionListener, null); + executor.preparePayloadAndInvoke(actionType, mlInput, null, actionListener); Mockito .verify(executor, times(1)) .invokeRemoteService(any(), any(), any(), argThat(argument -> argument.contains("You are a ${parameters.role}")), any(), any()); @@ -185,7 +185,7 @@ public void executePreparePayloadAndInvoke_SkipValidateMissingParameterDefault() Exception exception = Assert .assertThrows( IllegalArgumentException.class, - () -> executor.preparePayloadAndInvoke(actionType, mlInput, null, actionListener, null) + () -> executor.preparePayloadAndInvoke(actionType, mlInput, null, actionListener) ); assert exception.getMessage().contains("Some parameter placeholder not filled in payload: role"); } @@ -217,7 +217,7 @@ public void executePreparePayloadAndInvoke_PassingParameter() { Exception exception = Assert .assertThrows( IllegalArgumentException.class, - () -> executor.preparePayloadAndInvoke(actionType, mlInput, null, actionListener, null) + () -> executor.preparePayloadAndInvoke(actionType, mlInput, null, actionListener) ); assert exception.getMessage().contains("Some parameter placeholder not filled in payload: role"); } @@ -242,7 +242,7 @@ public void executePreparePayloadAndInvoke_GetParamsIOException() throws Excepti .inputDataset(inputDataSet) .build(); - executor.preparePayloadAndInvoke(actionType, mlInput, null, actionListener, null); + executor.preparePayloadAndInvoke(actionType, mlInput, null, actionListener); verify(actionListener).onFailure(argThat(e -> e instanceof IOException && e.getMessage().contains("UT test IOException"))); } @@ -372,11 +372,18 @@ public void executeAction_WithTransportChannel() { ModelTensors mockTensors = mock(ModelTensors.class); listener.onResponse(new Tuple<>(200, mockTensors)); return null; - }).when(executor).preparePayloadAndInvoke(any(), any(), any(), any(), any()); + }).when(executor).preparePayloadAndInvoke(any(), any(), any(), any(), any(), any()); executor.executeAction(PREDICT.name(), mlInput, streamActionListener, channel); verify(executor, times(1)) - .preparePayloadAndInvoke(eq(PREDICT.name()), eq(mlInput), any(ExecutionContext.class), any(ActionListener.class), eq(channel)); + .preparePayloadAndInvoke( + eq(PREDICT.name()), + eq(mlInput), + any(ExecutionContext.class), + any(ActionListener.class), + any(ActionListener.class), + eq(channel) + ); ArgumentCaptor responseCaptor = ArgumentCaptor.forClass(MLTaskResponse.class); verify(streamActionListener, times(1)).onResponse(responseCaptor.capture()); diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/StreamPredictActionListenerTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/StreamPredictActionListenerTest.java index 5c880c4e0f..0f587cae9f 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/StreamPredictActionListenerTest.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/StreamPredictActionListenerTest.java @@ -19,6 +19,7 @@ import org.opensearch.ml.common.output.model.ModelTensor; import org.opensearch.ml.common.output.model.ModelTensorOutput; import org.opensearch.ml.common.transport.MLTaskResponse; +import org.opensearch.ml.engine.algorithms.remote.streaming.StreamPredictActionListener; import org.opensearch.transport.TransportChannel; import org.opensearch.transport.TransportRequest; @@ -59,7 +60,6 @@ public void testOnResponse_CallsOnStreamResponseWithLastBatch() { listener.onResponse(mockResponse); verify(mockChannel).sendResponseBatch(mockResponse); - verify(mockChannel).completeStream(); } @Test diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/streaming/BaseStreamingHandlerTest.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/streaming/BaseStreamingHandlerTest.java new file mode 100644 index 0000000000..6ff2d3a9e6 --- /dev/null +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/remote/streaming/BaseStreamingHandlerTest.java @@ -0,0 +1,89 @@ +package org.opensearch.ml.engine.algorithms.remote.streaming; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyBoolean; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.verify; +import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.LLM_INTERFACE_OPENAI_V1_CHAT_COMPLETIONS; + +import java.util.concurrent.atomic.AtomicBoolean; + +import org.junit.Before; +import org.junit.Test; +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.opensearch.ml.common.connector.Connector; +import org.opensearch.ml.common.connector.ConnectorClientConfig; +import org.opensearch.ml.common.output.model.ModelTensor; +import org.opensearch.ml.common.output.model.ModelTensorOutput; +import org.opensearch.ml.common.output.model.ModelTensors; +import org.opensearch.ml.common.transport.MLTaskResponse; + +public class BaseStreamingHandlerTest { + + @Mock + private StreamPredictActionListener mockActionListener; + + @Mock + private Connector mockConnector; + + private BaseStreamingHandler streamingHandler; + private ConnectorClientConfig connectorClientConfig; + + @Before + public void setUp() { + MockitoAnnotations.openMocks(this); + connectorClientConfig = new ConnectorClientConfig(); + + streamingHandler = new HttpStreamingHandler(LLM_INTERFACE_OPENAI_V1_CHAT_COMPLETIONS, mockConnector, connectorClientConfig); + + } + + @Test + public void testSendContentResponse() { + String content = "test content"; + boolean isLast = false; + + streamingHandler.sendContentResponse(content, isLast, mockActionListener); + + ArgumentCaptor responseCaptor = ArgumentCaptor.forClass(MLTaskResponse.class); + ArgumentCaptor isLastCaptor = ArgumentCaptor.forClass(Boolean.class); + + verify(mockActionListener).onStreamResponse(responseCaptor.capture(), isLastCaptor.capture()); + + MLTaskResponse response = responseCaptor.getValue(); + ModelTensorOutput output = (ModelTensorOutput) response.getOutput(); + ModelTensors tensors = output.getMlModelOutputs().get(0); + ModelTensor tensor = tensors.getMlModelTensors().get(0); + + assertEquals("response", tensor.getName()); + assertEquals(content, tensor.getDataAsMap().get("content")); + assertEquals(isLast, tensor.getDataAsMap().get("is_last")); + assertEquals(isLast, isLastCaptor.getValue()); + } + + @Test + public void testSendContentResponseWithLastFlag() { + String content = "final content"; + boolean isLast = true; + + streamingHandler.sendContentResponse(content, isLast, mockActionListener); + + ArgumentCaptor responseCaptor = ArgumentCaptor.forClass(MLTaskResponse.class); + ArgumentCaptor isLastCaptor = ArgumentCaptor.forClass(Boolean.class); + + verify(mockActionListener).onStreamResponse(responseCaptor.capture(), isLastCaptor.capture()); + + assertTrue(isLastCaptor.getValue()); + } + + @Test + public void testSendCompletionResponseAlreadyClosed() { + AtomicBoolean isStreamClosed = new AtomicBoolean(true); + streamingHandler.sendCompletionResponse(isStreamClosed, mockActionListener); + verify(mockActionListener, never()).onStreamResponse(any(), anyBoolean()); + } +} diff --git a/plugin/build.gradle b/plugin/build.gradle index 5245df707b..84424ae092 100644 --- a/plugin/build.gradle +++ b/plugin/build.gradle @@ -353,6 +353,7 @@ List jacocoExclusions = [ 'org.opensearch.ml.task.MLTaskDispatcher', 'org.opensearch.ml.task.MLTrainAndPredictTaskRunner', 'org.opensearch.ml.task.MLExecuteTaskRunner', + 'org.opensearch.ml.task.MLExecuteTaskRunner.*', 'org.opensearch.ml.action.profile.MLProfileTransportAction', 'org.opensearch.ml.breaker.DiskCircuitBreaker', 'org.opensearch.ml.autoredeploy.MLModelAutoReDeployer.SearchRequestBuilderFactory', @@ -380,7 +381,9 @@ List jacocoExclusions = [ 'org.opensearch.ml.rest.RestMLDeleteMemoryAction', 'org.opensearch.ml.rest.RestMLDeleteMemoryAction.*', 'org.opensearch.ml.rest.RestMLPredictionStreamAction', - 'org.opensearch.ml.rest.RestMLPredictionStreamAction.*' + 'org.opensearch.ml.rest.RestMLPredictionStreamAction.*', + 'org.opensearch.ml.rest.RestMLExecuteStreamAction', + 'org.opensearch.ml.rest.RestMLExecuteStreamAction.*' ] jacocoTestCoverageVerification { diff --git a/plugin/src/main/java/org/opensearch/ml/action/execute/TransportExecuteStreamTaskAction.java b/plugin/src/main/java/org/opensearch/ml/action/execute/TransportExecuteStreamTaskAction.java new file mode 100644 index 0000000000..0d9801111e --- /dev/null +++ b/plugin/src/main/java/org/opensearch/ml/action/execute/TransportExecuteStreamTaskAction.java @@ -0,0 +1,97 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.action.execute; + +import static org.opensearch.ml.plugin.MachineLearningPlugin.STREAM_EXECUTE_THREAD_POOL; + +import org.opensearch.action.ActionRequest; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.support.HandledTransportAction; +import org.opensearch.common.Nullable; +import org.opensearch.common.inject.Inject; +import org.opensearch.core.action.ActionListener; +import org.opensearch.ml.common.FunctionName; +import org.opensearch.ml.common.transport.execute.MLExecuteStreamTaskAction; +import org.opensearch.ml.common.transport.execute.MLExecuteTaskRequest; +import org.opensearch.ml.common.transport.execute.MLExecuteTaskResponse; +import org.opensearch.ml.engine.algorithms.remote.streaming.StreamPredictActionListener; +import org.opensearch.ml.task.MLExecuteTaskRunner; +import org.opensearch.ml.task.MLTaskRunner; +import org.opensearch.tasks.Task; +import org.opensearch.transport.StreamTransportService; +import org.opensearch.transport.TransportChannel; +import org.opensearch.transport.TransportService; + +import lombok.AccessLevel; +import lombok.experimental.FieldDefaults; +import lombok.extern.log4j.Log4j2; + +@Log4j2 +@FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE) +public class TransportExecuteStreamTaskAction extends HandledTransportAction { + MLTaskRunner mlExecuteTaskRunner; + TransportService transportService; + + public static StreamTransportService streamTransportService; + private static StreamTransportService streamTransportServiceInstance; + + @Inject + public TransportExecuteStreamTaskAction( + TransportService transportService, + ActionFilters actionFilters, + MLExecuteTaskRunner mlExecuteTaskRunner, + @Nullable StreamTransportService streamTransportService + ) { + super(MLExecuteStreamTaskAction.NAME, transportService, actionFilters, MLExecuteTaskRequest::new); + this.mlExecuteTaskRunner = mlExecuteTaskRunner; + this.transportService = transportService; + if (streamTransportServiceInstance == null) { + streamTransportServiceInstance = streamTransportService; + } + this.streamTransportService = streamTransportServiceInstance; + + if (streamTransportService != null) { + streamTransportService + .registerRequestHandler( + MLExecuteStreamTaskAction.NAME, + STREAM_EXECUTE_THREAD_POOL, + MLExecuteTaskRequest::new, + this::messageReceived + ); + } else { + log.warn("StreamTransportService is not available."); + } + } + + public static StreamTransportService getStreamTransportService() { + return streamTransportService; + } + + public void messageReceived(MLExecuteTaskRequest request, TransportChannel channel, Task task) { + StreamPredictActionListener streamListener = new StreamPredictActionListener<>( + channel + ); + doExecute(task, request, streamListener, channel); + } + + @Override + protected void doExecute(Task task, ActionRequest request, ActionListener listener) { + // This should never be called for streaming action + listener.onFailure(new UnsupportedOperationException("Use doExecute with TransportChannel for streaming requests")); + } + + protected void doExecute(Task task, ActionRequest request, ActionListener listener, TransportChannel channel) { + MLExecuteTaskRequest mlExecuteTaskRequest = MLExecuteTaskRequest.fromActionRequest(request); + mlExecuteTaskRequest.setStreamingChannel(channel); + + if (mlExecuteTaskRequest.getStreamingChannel() != null) { + mlExecuteTaskRequest.setDispatchTask(false); + } + + FunctionName functionName = mlExecuteTaskRequest.getFunctionName(); + mlExecuteTaskRunner.run(functionName, mlExecuteTaskRequest, streamTransportService, listener); + } +} diff --git a/plugin/src/main/java/org/opensearch/ml/action/prediction/TransportPredictionStreamTaskAction.java b/plugin/src/main/java/org/opensearch/ml/action/prediction/TransportPredictionStreamTaskAction.java index 7947d59529..97fc07aac5 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/prediction/TransportPredictionStreamTaskAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/prediction/TransportPredictionStreamTaskAction.java @@ -29,7 +29,7 @@ import org.opensearch.ml.common.transport.MLTaskResponse; import org.opensearch.ml.common.transport.prediction.MLPredictionStreamTaskAction; import org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest; -import org.opensearch.ml.engine.algorithms.remote.StreamPredictActionListener; +import org.opensearch.ml.engine.algorithms.remote.streaming.StreamPredictActionListener; import org.opensearch.ml.helper.ModelAccessControlHelper; import org.opensearch.ml.model.MLModelCacheHelper; import org.opensearch.ml.model.MLModelManager; @@ -133,8 +133,12 @@ public void messageReceived(MLPredictionTaskRequest request, TransportChannel ch @Override protected void doExecute(Task task, ActionRequest request, ActionListener listener) { - // This should never be called for streaming action - listener.onFailure(new UnsupportedOperationException("Use doExecute with TransportChannel for streaming requests")); + TransportChannel channel = ((MLPredictionTaskRequest) request).getStreamingChannel(); + if (channel != null) { + doExecute(task, request, listener, channel); + } else { + listener.onFailure(new UnsupportedOperationException("Use doExecute with TransportChannel for streaming requests")); + } } protected void doExecute(Task task, ActionRequest request, ActionListener listener, TransportChannel channel) { diff --git a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java index e40a848322..a45158ff89 100644 --- a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java +++ b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java @@ -97,6 +97,7 @@ import org.opensearch.ml.action.controller.UpdateControllerTransportAction; import org.opensearch.ml.action.deploy.TransportDeployModelAction; import org.opensearch.ml.action.deploy.TransportDeployModelOnNodeAction; +import org.opensearch.ml.action.execute.TransportExecuteStreamTaskAction; import org.opensearch.ml.action.execute.TransportExecuteTaskAction; import org.opensearch.ml.action.forward.TransportForwardAction; import org.opensearch.ml.action.handler.MLSearchHandler; @@ -197,6 +198,7 @@ import org.opensearch.ml.common.transport.controller.MLUpdateControllerAction; import org.opensearch.ml.common.transport.deploy.MLDeployModelAction; import org.opensearch.ml.common.transport.deploy.MLDeployModelOnNodeAction; +import org.opensearch.ml.common.transport.execute.MLExecuteStreamTaskAction; import org.opensearch.ml.common.transport.execute.MLExecuteTaskAction; import org.opensearch.ml.common.transport.forward.MLForwardAction; import org.opensearch.ml.common.transport.indexInsight.MLIndexInsightConfigGetAction; @@ -333,6 +335,7 @@ import org.opensearch.ml.rest.RestMLDeleteTaskAction; import org.opensearch.ml.rest.RestMLDeployModelAction; import org.opensearch.ml.rest.RestMLExecuteAction; +import org.opensearch.ml.rest.RestMLExecuteStreamAction; import org.opensearch.ml.rest.RestMLGetAgentAction; import org.opensearch.ml.rest.RestMLGetConfigAction; import org.opensearch.ml.rest.RestMLGetConnectorAction; @@ -460,6 +463,7 @@ public class MachineLearningPlugin extends Plugin public static final String GENERAL_THREAD_POOL = "opensearch_ml_general"; public static final String SDK_CLIENT_THREAD_POOL = "opensearch_ml_sdkclient"; public static final String EXECUTE_THREAD_POOL = "opensearch_ml_execute"; + public static final String STREAM_EXECUTE_THREAD_POOL = "opensearch_ml_execute_stream"; public static final String TRAIN_THREAD_POOL = "opensearch_ml_train"; public static final String PREDICT_THREAD_POOL = "opensearch_ml_predict"; public static final String STREAM_PREDICT_THREAD_POOL = "opensearch_ml_predict_stream"; @@ -525,6 +529,7 @@ public MachineLearningPlugin() {} .of( new ActionHandler<>(MLStatsNodesAction.INSTANCE, MLStatsNodesTransportAction.class), new ActionHandler<>(MLExecuteTaskAction.INSTANCE, TransportExecuteTaskAction.class), + new ActionHandler<>(MLExecuteStreamTaskAction.INSTANCE, TransportExecuteStreamTaskAction.class), new ActionHandler<>(MLPredictionTaskAction.INSTANCE, TransportPredictionTaskAction.class), new ActionHandler<>(MLPredictionStreamTaskAction.INSTANCE, TransportPredictionStreamTaskAction.class), new ActionHandler<>(MLTrainingTaskAction.INSTANCE, TransportTrainingTaskAction.class), @@ -960,6 +965,7 @@ public List getRestHandlers( clusterService ); RestMLExecuteAction restMLExecuteAction = new RestMLExecuteAction(mlFeatureEnabledSetting); + RestMLExecuteStreamAction restMlExecuteStreamAction = new RestMLExecuteStreamAction(mlFeatureEnabledSetting, clusterService); RestMLGetModelAction restMLGetModelAction = new RestMLGetModelAction(mlFeatureEnabledSetting); RestMLDeleteModelAction restMLDeleteModelAction = new RestMLDeleteModelAction(mlFeatureEnabledSetting); RestMLSearchModelAction restMLSearchModelAction = new RestMLSearchModelAction(mlFeatureEnabledSetting); @@ -1062,6 +1068,7 @@ public List getRestHandlers( restMLPredictionAction, restMLPredictionStreamAction, restMLExecuteAction, + restMlExecuteStreamAction, restMLTrainAndPredictAction, restMLGetModelAction, restMLDeleteModelAction, @@ -1217,6 +1224,14 @@ public List> getExecutorBuilders(Settings settings) { ML_THREAD_POOL_PREFIX + STREAM_PREDICT_THREAD_POOL, false ); + FixedExecutorBuilder streamExecuteThreadPool = new FixedExecutorBuilder( + settings, + STREAM_EXECUTE_THREAD_POOL, + OpenSearchExecutors.allocatedProcessors(settings) * 10, + 1000000, + ML_THREAD_POOL_PREFIX + STREAM_EXECUTE_THREAD_POOL, + false + ); FixedExecutorBuilder mcpThreadPool = new FixedExecutorBuilder( settings, MCP_TOOLS_SYNC_THREAD_POOL, @@ -1247,6 +1262,7 @@ public List> getExecutorBuilders(Settings settings) { batchIngestThreadPool, sdkClientThreadPool, streamPredictThreadPool, + streamExecuteThreadPool, mcpThreadPool, agenticMemoryThreadPool ); diff --git a/plugin/src/main/java/org/opensearch/ml/rest/RestMLExecuteStreamAction.java b/plugin/src/main/java/org/opensearch/ml/rest/RestMLExecuteStreamAction.java new file mode 100644 index 0000000000..b76ca23f85 --- /dev/null +++ b/plugin/src/main/java/org/opensearch/ml/rest/RestMLExecuteStreamAction.java @@ -0,0 +1,344 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.rest; + +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.ml.plugin.MachineLearningPlugin.ML_BASE_URI; +import static org.opensearch.ml.plugin.MachineLearningPlugin.STREAM_EXECUTE_THREAD_POOL; +import static org.opensearch.ml.utils.MLExceptionUtils.AGENT_FRAMEWORK_DISABLED_ERR_MSG; +import static org.opensearch.ml.utils.MLExceptionUtils.STREAM_DISABLED_ERR_MSG; +import static org.opensearch.ml.utils.RestActionUtils.PARAMETER_AGENT_ID; +import static org.opensearch.ml.utils.RestActionUtils.isAsync; +import static org.opensearch.ml.utils.TenantAwareHelper.getTenantID; + +import java.io.IOException; +import java.io.UncheckedIOException; +import java.nio.ByteBuffer; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.concurrent.CompletableFuture; + +import org.opensearch.action.ActionRequestValidationException; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.lease.Releasable; +import org.opensearch.common.xcontent.LoggingDeprecationHandler; +import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.common.xcontent.support.XContentHttpChunk; +import org.opensearch.core.common.bytes.BytesReference; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.http.HttpChunk; +import org.opensearch.ml.action.execute.TransportExecuteStreamTaskAction; +import org.opensearch.ml.common.FunctionName; +import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; +import org.opensearch.ml.common.exception.MLException; +import org.opensearch.ml.common.input.Input; +import org.opensearch.ml.common.input.MLInput; +import org.opensearch.ml.common.input.execute.agent.AgentMLInput; +import org.opensearch.ml.common.output.model.ModelTensor; +import org.opensearch.ml.common.output.model.ModelTensorOutput; +import org.opensearch.ml.common.output.model.ModelTensors; +import org.opensearch.ml.common.settings.MLFeatureEnabledSetting; +import org.opensearch.ml.common.transport.MLTaskResponse; +import org.opensearch.ml.common.transport.execute.MLExecuteStreamTaskAction; +import org.opensearch.ml.common.transport.execute.MLExecuteTaskRequest; +import org.opensearch.ml.repackage.com.google.common.annotations.VisibleForTesting; +import org.opensearch.ml.repackage.com.google.common.collect.ImmutableList; +import org.opensearch.rest.BaseRestHandler; +import org.opensearch.rest.BytesRestResponse; +import org.opensearch.rest.RestRequest; +import org.opensearch.rest.StreamingRestChannel; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.transport.StreamTransportResponseHandler; +import org.opensearch.transport.StreamTransportService; +import org.opensearch.transport.TransportException; +import org.opensearch.transport.TransportRequestOptions; +import org.opensearch.transport.client.node.NodeClient; +import org.opensearch.transport.stream.StreamTransportResponse; + +import lombok.extern.log4j.Log4j2; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + +@Log4j2 +public class RestMLExecuteStreamAction extends BaseRestHandler { + + private static final String ML_EXECUTE_STREAM_ACTION = "ml_execute_stream_action"; + private final MLFeatureEnabledSetting mlFeatureEnabledSetting; + private ClusterService clusterService; + + /** + * Constructor + */ + public RestMLExecuteStreamAction(MLFeatureEnabledSetting mlFeatureEnabledSetting, ClusterService clusterService) { + this.mlFeatureEnabledSetting = mlFeatureEnabledSetting; + this.clusterService = clusterService; + } + + @Override + public String getName() { + return ML_EXECUTE_STREAM_ACTION; + } + + @Override + public List routes() { + return ImmutableList + .of( + new Route( + RestRequest.Method.POST, + String.format(Locale.ROOT, "%s/agents/{%s}/_execute/stream", ML_BASE_URI, PARAMETER_AGENT_ID) + ) + ); + } + + @Override + public boolean supportsContentStream() { + return true; + } + + @Override + public boolean supportsStreaming() { + return true; + } + + @Override + public boolean allowsUnsafeBuffers() { + return true; + } + + @Override + public RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) throws IOException { + if (!mlFeatureEnabledSetting.isStreamEnabled()) { + throw new IllegalStateException(STREAM_DISABLED_ERR_MSG); + } + + String agentId = request.param(PARAMETER_AGENT_ID); + + final StreamingRestChannelConsumer consumer = (channel) -> { + Map> headers = Map + .of( + "Content-Type", + List.of("text/event-stream"), + "Cache-Control", + List.of("no-cache"), + "Connection", + List.of("keep-alive") + ); + channel.prepareResponse(RestStatus.OK, headers); + + Flux.from(channel).ofType(HttpChunk.class).concatMap(chunk -> { + final CompletableFuture future = new CompletableFuture<>(); + try { + MLExecuteTaskRequest mlExecuteTaskRequest = getRequest(agentId, request, chunk.content()); + StreamTransportResponseHandler handler = new StreamTransportResponseHandler() { + @Override + public void handleStreamResponse(StreamTransportResponse streamResponse) { + try { + MLTaskResponse response = streamResponse.nextResponse(); + + if (response != null) { + HttpChunk responseChunk = convertToHttpChunk(response); + channel.sendChunk(responseChunk); + + // Recursively handle the next response + client + .threadPool() + .executor(STREAM_EXECUTE_THREAD_POOL) + .execute(() -> handleStreamResponse(streamResponse)); + } else { + log.info("No more responses, closing stream"); + future.complete(XContentHttpChunk.last()); + streamResponse.close(); + } + } catch (Exception e) { + future.completeExceptionally(e); + log.error("Error in stream handling", e); + } + } + + @Override + public void handleException(TransportException exp) { + future.completeExceptionally(exp); + } + + @Override + public String executor() { + return ThreadPool.Names.SAME; + } + + @Override + public MLTaskResponse read(StreamInput in) throws IOException { + return new MLTaskResponse(in); + } + }; + + StreamTransportService streamTransportService = TransportExecuteStreamTaskAction.getStreamTransportService(); + streamTransportService + .sendRequest( + clusterService.localNode(), + MLExecuteStreamTaskAction.NAME, + mlExecuteTaskRequest, + TransportRequestOptions.builder().withType(TransportRequestOptions.Type.STREAM).build(), + handler + ); + + } catch (IOException e) { + throw new MLException("Got an exception in flux.", e); + } + + return Mono.fromCompletionStage(future); + }).doOnNext(channel::sendChunk).onErrorComplete(ex -> { + // Error handling + try { + channel.sendResponse(new BytesRestResponse(channel, (Exception) ex)); + return true; + } catch (final IOException e) { + throw new UncheckedIOException(e); + } + }).subscribe(); + }; + + return channel -> { + if (channel instanceof StreamingRestChannel) { + consumer.accept((StreamingRestChannel) channel); + } else { + final ActionRequestValidationException validationError = new ActionRequestValidationException(); + validationError.addValidationError("Unable to initiate request / response streaming over non-streaming channel"); + channel.sendResponse(new BytesRestResponse(channel, validationError)); + } + }; + } + + /** + * Creates a MLExecuteTaskRequest from a RestRequest + * + * @param request RestRequest + * @return MLExecuteTaskRequest + */ + @VisibleForTesting + MLExecuteTaskRequest getRequest(String agentId, RestRequest request, BytesReference content) throws IOException { + XContentParser parser = request + .getMediaType() + .xContent() + .createParser(request.getXContentRegistry(), LoggingDeprecationHandler.INSTANCE, content.streamInput()); + boolean async = isAsync(request); + ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); + + if (!mlFeatureEnabledSetting.isAgentFrameworkEnabled()) { + throw new IllegalStateException(AGENT_FRAMEWORK_DISABLED_ERR_MSG); + } + String tenantId = getTenantID(mlFeatureEnabledSetting.isMultiTenancyEnabled(), request); + FunctionName functionName = FunctionName.AGENT; + Input input = MLInput.parse(parser, functionName.name()); + AgentMLInput agentInput = (AgentMLInput) input; + agentInput.setAgentId(agentId); + agentInput.setTenantId(tenantId); + agentInput.setIsAsync(async); + RemoteInferenceInputDataSet inputDataSet = (RemoteInferenceInputDataSet) agentInput.getInputDataset(); + inputDataSet.getParameters().put("stream", String.valueOf(true)); + return new MLExecuteTaskRequest(functionName, input); + } + + private HttpChunk convertToHttpChunk(MLTaskResponse response) throws IOException { + String memoryId = ""; + String parentInteractionId = ""; + String content = ""; + boolean isLast = false; + + // TODO: refactor to handle other types of agents + // Extract values from multiple tensors + try { + ModelTensorOutput output = (ModelTensorOutput) response.getOutput(); + if (output != null && !output.getMlModelOutputs().isEmpty()) { + ModelTensors modelTensors = output.getMlModelOutputs().get(0); + List tensors = modelTensors.getMlModelTensors(); + + for (ModelTensor tensor : tensors) { + String name = tensor.getName(); + if ("memory_id".equals(name) && tensor.getResult() != null) { + memoryId = tensor.getResult(); + } else if ("parent_interaction_id".equals(name) && tensor.getResult() != null) { + parentInteractionId = tensor.getResult(); + } else if (("llm_response".equals(name) || "response".equals(name)) && tensor.getDataAsMap() != null) { + Map dataMap = tensor.getDataAsMap(); + if (dataMap.containsKey("content")) { + content = (String) dataMap.get("content"); + if (content == null) + content = ""; + } + if (dataMap.containsKey("is_last")) { + isLast = Boolean.TRUE.equals(dataMap.get("is_last")); + } + } + } + } + } catch (Exception e) { + log.error("Failed to extract values from response", e); + } + + String finalContent = content; + boolean finalIsLast = isLast; + + log + .info( + "Converting to HttpChunk - memoryId: '{}', parentId: '{}', content: '{}', isLast: {}", + memoryId, + parentInteractionId, + content, + isLast + ); + + // Create ordered tensors + List orderedTensors = List + .of( + ModelTensor.builder().name("memory_id").result(memoryId).build(), + ModelTensor.builder().name("parent_interaction_id").result(parentInteractionId).build(), + ModelTensor.builder().name("response").dataAsMap(new LinkedHashMap() { + { + put("content", finalContent); + put("is_last", finalIsLast); + } + }).build() + ); + + ModelTensors tensors = ModelTensors.builder().mlModelTensors(orderedTensors).build(); + + ModelTensorOutput tensorOutput = ModelTensorOutput.builder().mlModelOutputs(List.of(tensors)).build(); + + XContentBuilder builder = XContentFactory.jsonBuilder(); + tensorOutput.toXContent(builder, ToXContent.EMPTY_PARAMS); + String jsonData = builder.toString(); + + String sseData = "data: " + jsonData + "\n\n"; + return createHttpChunk(sseData, isLast); + } + + private HttpChunk createHttpChunk(String sseData, boolean isLast) { + BytesReference bytesRef = BytesReference.fromByteBuffer(ByteBuffer.wrap(sseData.getBytes())); + return new HttpChunk() { + @Override + public void close() { + if (bytesRef instanceof Releasable) + ((Releasable) bytesRef).close(); + } + + @Override + public boolean isLast() { + return isLast; + } + + @Override + public BytesReference content() { + return bytesRef; + } + }; + } +} diff --git a/plugin/src/main/java/org/opensearch/ml/task/MLExecuteTaskRunner.java b/plugin/src/main/java/org/opensearch/ml/task/MLExecuteTaskRunner.java index d57c98e3cf..73281e0333 100644 --- a/plugin/src/main/java/org/opensearch/ml/task/MLExecuteTaskRunner.java +++ b/plugin/src/main/java/org/opensearch/ml/task/MLExecuteTaskRunner.java @@ -7,14 +7,19 @@ import static org.opensearch.ml.common.settings.MLCommonsSettings.ML_COMMONS_ENABLE_INHOUSE_PYTHON_MODEL; import static org.opensearch.ml.plugin.MachineLearningPlugin.EXECUTE_THREAD_POOL; +import static org.opensearch.ml.plugin.MachineLearningPlugin.STREAM_EXECUTE_THREAD_POOL; + +import java.io.IOException; import org.opensearch.action.ActionListenerResponseHandler; import org.opensearch.cluster.service.ClusterService; import org.opensearch.core.action.ActionListener; +import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.ml.breaker.MLCircuitBreakerService; import org.opensearch.ml.cluster.DiscoveryNodeHelper; import org.opensearch.ml.common.FunctionName; import org.opensearch.ml.common.input.Input; +import org.opensearch.ml.common.transport.execute.MLExecuteStreamTaskAction; import org.opensearch.ml.common.transport.execute.MLExecuteTaskAction; import org.opensearch.ml.common.transport.execute.MLExecuteTaskRequest; import org.opensearch.ml.common.transport.execute.MLExecuteTaskResponse; @@ -25,8 +30,12 @@ import org.opensearch.ml.stats.MLNodeLevelStat; import org.opensearch.ml.stats.MLStats; import org.opensearch.threadpool.ThreadPool; +import org.opensearch.transport.StreamTransportResponseHandler; +import org.opensearch.transport.TransportChannel; +import org.opensearch.transport.TransportException; import org.opensearch.transport.TransportResponseHandler; import org.opensearch.transport.client.Client; +import org.opensearch.transport.stream.StreamTransportResponse; import lombok.extern.log4j.Log4j2; @@ -75,7 +84,12 @@ protected String getTransportActionName() { @Override protected String getTransportStreamActionName() { - return MLExecuteTaskAction.NAME; + return MLExecuteStreamTaskAction.NAME; + } + + @Override + protected boolean isStreamingRequest(MLExecuteTaskRequest request) { + return request.getStreamingChannel() != null; } @Override @@ -83,6 +97,45 @@ protected TransportResponseHandler getResponseHandler(Act return new ActionListenerResponseHandler<>(listener, MLExecuteTaskResponse::new); } + @Override + protected TransportResponseHandler getResponseStreamHandler(MLExecuteTaskRequest request) { + TransportChannel channel = request.getStreamingChannel(); + return new StreamTransportResponseHandler() { + @Override + public void handleStreamResponse(StreamTransportResponse streamResponse) { + try { + MLExecuteTaskResponse response; + while ((response = streamResponse.nextResponse()) != null) { + channel.sendResponseBatch(response); + } + channel.completeStream(); + streamResponse.close(); + } catch (Exception e) { + streamResponse.cancel("Stream error", e); + } + } + + @Override + public void handleException(TransportException exp) { + try { + channel.sendResponse(exp); + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + @Override + public String executor() { + return ThreadPool.Names.SAME; + } + + @Override + public MLExecuteTaskResponse read(StreamInput in) throws IOException { + return new MLExecuteTaskResponse(in); + } + }; + } + /** * Execute algorithm and return result. * @param request MLExecuteTaskRequest @@ -90,7 +143,9 @@ protected TransportResponseHandler getResponseHandler(Act */ @Override protected void executeTask(MLExecuteTaskRequest request, ActionListener listener) { - threadPool.executor(EXECUTE_THREAD_POOL).execute(() -> { + TransportChannel channel = request.getStreamingChannel(); + String threadPoolName = (channel != null) ? STREAM_EXECUTE_THREAD_POOL : EXECUTE_THREAD_POOL; + threadPool.executor(threadPoolName).execute(() -> { try { mlStats.getStat(MLNodeLevelStat.ML_EXECUTING_TASK_COUNT).increment(); mlStats.getStat(MLNodeLevelStat.ML_REQUEST_COUNT).increment(); @@ -111,7 +166,7 @@ protected void executeTask(MLExecuteTaskRequest request, ActionListener { MLExecuteTaskResponse response = new MLExecuteTaskResponse(functionName, output); listener.onResponse(response); - }, e -> { listener.onFailure(e); })); + }, e -> { listener.onFailure(e); }), channel); } catch (Exception e) { mlStats .createCounterStatIfAbsent(request.getFunctionName(), ActionName.EXECUTE, MLActionLevelStat.ML_ACTION_FAILURE_COUNT) diff --git a/plugin/src/main/java/org/opensearch/ml/task/MLPredictTaskRunner.java b/plugin/src/main/java/org/opensearch/ml/task/MLPredictTaskRunner.java index 321ab7bc93..a40c76ba75 100644 --- a/plugin/src/main/java/org/opensearch/ml/task/MLPredictTaskRunner.java +++ b/plugin/src/main/java/org/opensearch/ml/task/MLPredictTaskRunner.java @@ -155,7 +155,9 @@ protected TransportResponseHandler getResponseHandler(ActionList return new ActionListenerResponseHandler<>(listener, MLTaskResponse::new); } - private TransportResponseHandler getResponseStreamHandler(TransportChannel channel) { + @Override + protected TransportResponseHandler getResponseStreamHandler(MLPredictionTaskRequest request) { + TransportChannel channel = request.getStreamingChannel(); return new StreamTransportResponseHandler() { @Override public void handleStreamResponse(StreamTransportResponse streamResponse) { @@ -228,7 +230,6 @@ public void dispatchTask( request.setDispatchTask(false); // Check if this is a streaming request if (isStreamingRequest(request)) { - TransportChannel channel = request.getStreamingChannel(); log.debug("Using streaming transport for request {}", request.getRequestID()); transportService .sendRequest( @@ -236,7 +237,7 @@ public void dispatchTask( getTransportStreamActionName(), request, TransportRequestOptions.builder().withType(TransportRequestOptions.Type.STREAM).build(), - getResponseStreamHandler(channel) + getResponseStreamHandler(request) ); } else { transportService.sendRequest(node, getTransportActionName(), request, getResponseHandler(listener)); @@ -364,17 +365,9 @@ protected void executeTask(MLPredictionTaskRequest request, ActionListener parameters = inputDataSet.getParameters(); - return parameters != null && "true".equals(parameters.get("stream")); - } - } catch (Exception e) { - log.debug("Failed to check streaming parameter, defaulting to non-streaming", e); - } - return false; + @Override + protected boolean isStreamingRequest(MLPredictionTaskRequest request) { + return request.getStreamingChannel() != null; } private void executePredictionByInputDataType( diff --git a/plugin/src/main/java/org/opensearch/ml/task/MLTaskRunner.java b/plugin/src/main/java/org/opensearch/ml/task/MLTaskRunner.java index 0193e6b7bb..4e68f79b26 100644 --- a/plugin/src/main/java/org/opensearch/ml/task/MLTaskRunner.java +++ b/plugin/src/main/java/org/opensearch/ml/task/MLTaskRunner.java @@ -22,6 +22,7 @@ import org.opensearch.ml.common.transport.MLTaskResponse; import org.opensearch.ml.stats.MLNodeLevelStat; import org.opensearch.ml.stats.MLStats; +import org.opensearch.transport.TransportRequestOptions; import org.opensearch.transport.TransportResponseHandler; import org.opensearch.transport.TransportService; @@ -117,7 +118,19 @@ public void dispatchTask( // Execute ML task remotely log.debug("Execute ML request {} remotely on node {}", request.getRequestID(), nodeId); request.setDispatchTask(false); - transportService.sendRequest(node, getTransportActionName(), request, getResponseHandler(listener)); + if (isStreamingRequest(request)) { + log.debug("Using streaming transport for request {}", request.getRequestID()); + transportService + .sendRequest( + node, + getTransportStreamActionName(), + request, + TransportRequestOptions.builder().withType(TransportRequestOptions.Type.STREAM).build(), + getResponseStreamHandler(request) + ); + } else { + transportService.sendRequest(node, getTransportActionName(), request, getResponseHandler(listener)); + } } }, listener::onFailure)); } @@ -128,6 +141,14 @@ public void dispatchTask( protected abstract TransportResponseHandler getResponseHandler(ActionListener listener); + protected TransportResponseHandler getResponseStreamHandler(Request request) { + throw new UnsupportedOperationException("Streaming is not supported."); + } + + protected boolean isStreamingRequest(Request request) { + return false; + } + protected abstract void executeTask(Request request, ActionListener listener); protected void checkCBAndExecute(FunctionName functionName, Request request, ActionListener listener) { diff --git a/plugin/src/test/java/org/opensearch/ml/action/execute/TransportExecuteStreamTaskActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/execute/TransportExecuteStreamTaskActionTests.java new file mode 100644 index 0000000000..e38a61c183 --- /dev/null +++ b/plugin/src/test/java/org/opensearch/ml/action/execute/TransportExecuteStreamTaskActionTests.java @@ -0,0 +1,118 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.action.execute; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import org.junit.Before; +import org.junit.Test; +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.util.concurrent.ThreadContext; +import org.opensearch.core.action.ActionListener; +import org.opensearch.ml.common.FunctionName; +import org.opensearch.ml.common.input.Input; +import org.opensearch.ml.common.settings.MLFeatureEnabledSetting; +import org.opensearch.ml.common.transport.execute.MLExecuteTaskRequest; +import org.opensearch.ml.common.transport.execute.MLExecuteTaskResponse; +import org.opensearch.ml.task.MLExecuteTaskRunner; +import org.opensearch.tasks.Task; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.transport.StreamTransportService; +import org.opensearch.transport.TransportChannel; +import org.opensearch.transport.TransportService; +import org.opensearch.transport.client.Client; + +public class TransportExecuteStreamTaskActionTests extends OpenSearchTestCase { + + @Mock + private MLExecuteTaskRunner mlExecuteTaskRunner; + + @Mock + private TransportService transportService; + + @Mock + private Client client; + + @Mock + private ClusterService clusterService; + + @Mock + private ActionListener actionListener; + + @Mock + private ActionFilters actionFilters; + + @Mock + private MLFeatureEnabledSetting mlFeatureEnabledSetting; + + @Mock + private StreamTransportService streamTransportService; + + @Mock + private TransportChannel transportChannel; + + @Mock + private ThreadPool threadPool; + + private MLExecuteTaskRequest mlExecuteTaskRequest; + private TransportExecuteStreamTaskAction transportExecuteStreamTaskAction; + private ThreadContext threadContext; + + @Before + public void setup() { + MockitoAnnotations.openMocks(this); + + Settings settings = Settings.builder().build(); + threadContext = new ThreadContext(settings); + + when(client.threadPool()).thenReturn(threadPool); + when(threadPool.getThreadContext()).thenReturn(threadContext); + when(clusterService.getSettings()).thenReturn(settings); + when(mlFeatureEnabledSetting.isAgentFrameworkEnabled()).thenReturn(true); + + mlExecuteTaskRequest = MLExecuteTaskRequest.builder().functionName(FunctionName.AGENT).input(mock(Input.class)).build(); + + transportExecuteStreamTaskAction = spy( + new TransportExecuteStreamTaskAction(transportService, actionFilters, mlExecuteTaskRunner, streamTransportService) + ); + } + + @Test + public void testGetStreamTransportService() { + StreamTransportService result = TransportExecuteStreamTaskAction.getStreamTransportService(); + assertNotNull(result); + } + + @Test + public void testMessageReceived() { + Task task = mock(Task.class); + transportExecuteStreamTaskAction.messageReceived(mlExecuteTaskRequest, transportChannel, task); + + verify(transportExecuteStreamTaskAction).doExecute(eq(task), eq(mlExecuteTaskRequest), any(), eq(transportChannel)); + } + + @Test + public void testDoExecuteWithoutChannel() { + transportExecuteStreamTaskAction.doExecute(null, mlExecuteTaskRequest, actionListener); + + ArgumentCaptor captor = ArgumentCaptor.forClass(Exception.class); + verify(actionListener).onFailure(captor.capture()); + assertTrue(captor.getValue() instanceof UnsupportedOperationException); + assertEquals("Use doExecute with TransportChannel for streaming requests", captor.getValue().getMessage()); + } + +} diff --git a/plugin/src/test/java/org/opensearch/ml/plugin/MachineLearningPluginTests.java b/plugin/src/test/java/org/opensearch/ml/plugin/MachineLearningPluginTests.java index 0575ba1e88..a037dcfdc7 100644 --- a/plugin/src/test/java/org/opensearch/ml/plugin/MachineLearningPluginTests.java +++ b/plugin/src/test/java/org/opensearch/ml/plugin/MachineLearningPluginTests.java @@ -233,7 +233,7 @@ public void testGetExecutorBuilders() { Settings settings = Settings.EMPTY; List> executorBuilders = plugin.getExecutorBuilders(settings); assertNotNull(executorBuilders); - assertEquals(12, executorBuilders.size()); + assertEquals(13, executorBuilders.size()); // Verify we have the expected number of thread pools assertTrue(executorBuilders.size() > 5); diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLExecuteStreamActionTests.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLExecuteStreamActionTests.java new file mode 100644 index 0000000000..02536d506d --- /dev/null +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLExecuteStreamActionTests.java @@ -0,0 +1,168 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.rest; + +import static org.junit.Assert.assertEquals; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.when; +import static org.opensearch.ml.utils.MLExceptionUtils.STREAM_DISABLED_ERR_MSG; +import static org.opensearch.ml.utils.RestActionUtils.PARAMETER_AGENT_ID; +import static org.opensearch.ml.utils.TestHelper.getExecuteAgentStreamRestRequest; + +import java.io.IOException; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import org.junit.Before; +import org.junit.Test; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.common.bytes.BytesArray; +import org.opensearch.ml.common.FunctionName; +import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; +import org.opensearch.ml.common.input.Input; +import org.opensearch.ml.common.input.execute.agent.AgentMLInput; +import org.opensearch.ml.common.settings.MLFeatureEnabledSetting; +import org.opensearch.ml.common.transport.execute.MLExecuteTaskRequest; +import org.opensearch.rest.RestChannel; +import org.opensearch.rest.RestRequest; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.test.rest.FakeRestRequest; +import org.opensearch.threadpool.TestThreadPool; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.transport.client.node.NodeClient; + +public class RestMLExecuteStreamActionTests extends OpenSearchTestCase { + + NodeClient client; + private ThreadPool threadPool; + + @Mock + RestChannel channel; + + private RestMLExecuteStreamAction restAction; + private MLFeatureEnabledSetting mlFeatureEnabledSetting; + private ClusterService clusterService; + + @Before + public void setup() { + MockitoAnnotations.openMocks(this); + mlFeatureEnabledSetting = mock(MLFeatureEnabledSetting.class); + clusterService = mock(ClusterService.class); + restAction = new RestMLExecuteStreamAction(mlFeatureEnabledSetting, clusterService); + threadPool = new TestThreadPool(this.getClass().getSimpleName() + "ThreadPool"); + client = spy(new NodeClient(Settings.EMPTY, threadPool)); + } + + @Override + public void tearDown() throws Exception { + super.tearDown(); + threadPool.shutdown(); + client.close(); + } + + @Test + public void testGetName() { + assertEquals("ml_execute_stream_action", restAction.getName()); + } + + @Test + public void testRoutes() { + List routes = restAction.routes(); + assertEquals(1, routes.size()); + + assertEquals(RestRequest.Method.POST, routes.get(0).getMethod()); + assertTrue(routes.get(0).getPath().contains("/agents/")); + assertTrue(routes.get(0).getPath().contains("/_execute/stream")); + } + + @Test + public void testConstructor() { + assertNotNull(restAction); + RestMLExecuteStreamAction newAction = new RestMLExecuteStreamAction(mlFeatureEnabledSetting, clusterService); + assertNotNull(newAction); + assertEquals("ml_execute_stream_action", newAction.getName()); + } + + @Test + public void testSupportsContentStream() { + assertTrue(restAction.supportsContentStream()); + } + + @Test + public void testSupportsStreaming() { + assertTrue(restAction.supportsStreaming()); + } + + @Test + public void testAllowsUnsafeBuffers() { + assertTrue(restAction.allowsUnsafeBuffers()); + } + + @Test + public void testPrepareRequestWhenStreamEnabled() throws IOException { + when(mlFeatureEnabledSetting.isStreamEnabled()).thenReturn(true); + when(mlFeatureEnabledSetting.isAgentFrameworkEnabled()).thenReturn(true); + + RestRequest request = getExecuteAgentStreamRestRequest(); + assertNotNull(restAction.prepareRequest(request, client)); + } + + @Test + public void testPrepareRequestWhenStreamDisabled() throws IOException { + when(mlFeatureEnabledSetting.isStreamEnabled()).thenReturn(false); + RestRequest request = getExecuteAgentStreamRestRequest(); + + IllegalStateException exception = assertThrows( + IllegalStateException.class, + () -> restAction.prepareRequest(request, null) + ); + assertEquals(STREAM_DISABLED_ERR_MSG, exception.getMessage()); + } + + @Test + public void testGetRequestAgent() throws IOException { + when(mlFeatureEnabledSetting.isAgentFrameworkEnabled()).thenReturn(true); + + Map params = new HashMap<>(); + params.put(PARAMETER_AGENT_ID, "test_agent_id"); + final String requestContent = "{\"parameters\":{\"question\":\"test question\"}}"; + + RestRequest request = new FakeRestRequest.Builder(xContentRegistry()) + .withParams(params) + .withContent(new BytesArray(requestContent), XContentType.JSON) + .withPath("/_plugins/_ml/agents/test_agent_id/_execute/stream") + .build(); + + String agentId = "test_agent_id"; + + MLExecuteTaskRequest executeTaskRequest = restAction.getRequest(agentId, request, request.content()); + + Input input = executeTaskRequest.getInput(); + assertNotNull(input); + assertEquals(FunctionName.AGENT, input.getFunctionName()); + + AgentMLInput agentInput = (AgentMLInput) input; + assertEquals(agentId, agentInput.getAgentId()); + + RemoteInferenceInputDataSet inputDataSet = (RemoteInferenceInputDataSet) agentInput.getInputDataset(); + assertNotNull(inputDataSet); + assertEquals("true", inputDataSet.getParameters().get("stream")); + } + + @Test + public void testGetRequestAgentFrameworkDisabled() { + RestRequest request = getExecuteAgentStreamRestRequest(); + + when(mlFeatureEnabledSetting.isAgentFrameworkEnabled()).thenReturn(false); + assertThrows(IllegalStateException.class, () -> restAction.handleRequest(request, channel, client)); + } +} diff --git a/plugin/src/test/java/org/opensearch/ml/task/MLExecuteTaskRunnerTests.java b/plugin/src/test/java/org/opensearch/ml/task/MLExecuteTaskRunnerTests.java index 5883313cfc..1f53744661 100644 --- a/plugin/src/test/java/org/opensearch/ml/task/MLExecuteTaskRunnerTests.java +++ b/plugin/src/test/java/org/opensearch/ml/task/MLExecuteTaskRunnerTests.java @@ -8,6 +8,7 @@ import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; import static org.mockito.Mockito.spy; import static org.mockito.Mockito.verify; @@ -36,6 +37,7 @@ import org.opensearch.ml.cluster.DiscoveryNodeHelper; import org.opensearch.ml.common.FunctionName; import org.opensearch.ml.common.input.execute.samplecalculator.LocalSampleCalculatorInput; +import org.opensearch.ml.common.output.Output; import org.opensearch.ml.common.transport.execute.MLExecuteTaskRequest; import org.opensearch.ml.common.transport.execute.MLExecuteTaskResponse; import org.opensearch.ml.engine.MLEngine; @@ -93,7 +95,7 @@ public class MLExecuteTaskRunnerTests extends OpenSearchTestCase { public void setup() { MockitoAnnotations.openMocks(this); encryptor = new EncryptorImpl(null, "m+dWmfmnNRiNlOdej/QelEkvMTyH//frS2TBeS2BP4w="); - mlEngine = new MLEngine(Path.of("/tmp/djl-cache/" + randomAlphaOfLength(10)), encryptor); + mlEngine = spy(new MLEngine(Path.of("/tmp/djl-cache/" + randomAlphaOfLength(10)), encryptor)); when(threadPool.executor(anyString())).thenReturn(executorService); doAnswer(invocation -> { Runnable runnable = invocation.getArgument(0); @@ -146,7 +148,14 @@ public void setup() { ); } - public void testExecuteTask_Success() { + public void testExecuteTask_Success() throws Exception { + Output mockOutput = mock(Output.class); + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(1); + actionListener.onResponse(mockOutput); + return null; + }).when(mlEngine).execute(any(), any(), any()); + taskRunner.executeTask(mlExecuteTaskRequest, listener); verify(listener).onResponse(any(MLExecuteTaskResponse.class)); } diff --git a/plugin/src/test/java/org/opensearch/ml/task/MLPredictTaskRunnerTests.java b/plugin/src/test/java/org/opensearch/ml/task/MLPredictTaskRunnerTests.java index b8318a296e..6ac4dc9262 100644 --- a/plugin/src/test/java/org/opensearch/ml/task/MLPredictTaskRunnerTests.java +++ b/plugin/src/test/java/org/opensearch/ml/task/MLPredictTaskRunnerTests.java @@ -89,6 +89,7 @@ import org.opensearch.search.builder.SearchSourceBuilder; import org.opensearch.test.OpenSearchTestCase; import org.opensearch.threadpool.ThreadPool; +import org.opensearch.transport.TransportChannel; import org.opensearch.transport.TransportService; import org.opensearch.transport.client.Client; @@ -606,12 +607,14 @@ public void testIsStreamingRequest() { } } - public void testIsStreamingRequestWithStreamParameter() { - Map parameters = new HashMap<>(); - parameters.put("stream", "true"); - RemoteInferenceInputDataSet inputDataSet = RemoteInferenceInputDataSet.builder().parameters(parameters).build(); - MLInput mlInput = MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(inputDataSet).build(); + public void testIsStreamingRequestWithChannel() { + MLInput mlInput = MLInput + .builder() + .algorithm(FunctionName.REMOTE) + .inputDataset(new TextDocsInputDataSet(List.of("test"), null)) + .build(); MLPredictionTaskRequest request = MLPredictionTaskRequest.builder().modelId("test").mlInput(mlInput).build(); + request.setStreamingChannel(mock(TransportChannel.class)); try { java.lang.reflect.Method method = MLPredictTaskRunner.class diff --git a/plugin/src/test/java/org/opensearch/ml/utils/TestHelper.java b/plugin/src/test/java/org/opensearch/ml/utils/TestHelper.java index e362c1fed8..71b038526c 100644 --- a/plugin/src/test/java/org/opensearch/ml/utils/TestHelper.java +++ b/plugin/src/test/java/org/opensearch/ml/utils/TestHelper.java @@ -422,6 +422,22 @@ public static RestRequest getExecuteAgentRestRequest() { .build(); } + public static RestRequest getExecuteAgentStreamRestRequest() { + Map params = new HashMap<>(); + params.put(PARAMETER_AGENT_ID, "test_agent_id"); + final String requestContent = "{\"name\":\"Test_Agent_For_RAG\",\"type\":\"flow\"," + + "\"description\":\"this is a test agent\",\"app_type\":\"my app\"," + + "\"tools\":[{\"type\":\"ListIndexTool\",\"name\":\"ListIndexTool\"," + + "\"description\":\"Use this tool to get OpenSearch index information: " + + "(health, status, index, uuid, primary count, replica count, docs.count, docs.deleted, " + + "store.size, primary.store.size).\",\"include_output_in_agent_response\":true}]}"; + return new FakeRestRequest.Builder(getXContentRegistry()) + .withParams(params) + .withContent(new BytesArray(requestContent), XContentType.JSON) + .withPath("/_plugins/_ml/agents/test_agent_id/_execute/stream") + .build(); + } + public static RestRequest getExecuteToolRestRequest() { Map params = new HashMap<>(); params.put(PARAMETER_TOOL_NAME, "TestTool");