Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import java.util.regex.Matcher;
import java.util.regex.Pattern;

import org.apache.commons.lang3.StringUtils;
import org.apache.commons.text.StringEscapeUtils;
import org.apache.commons.text.StringSubstitutor;
import org.opensearch.Version;
Expand Down Expand Up @@ -56,6 +57,8 @@ public class HttpConnector extends AbstractConnector {
public static final String PARAMETERS_FIELD = "parameters";
public static final String SERVICE_NAME_FIELD = "service_name";
public static final String REGION_FIELD = "region";
// TODO: move the AgentUtils class from algorithm module to common module
public static final String LLM_INTERFACE_OPENAI_V1_CHAT_COMPLETIONS = "openai/v1/chat/completions";

// TODO: add RequestConfig like request time out,

Expand Down Expand Up @@ -377,14 +380,14 @@ private boolean neededStreamParameterInPayload(Map<String, String> parameters) {
}

String llmInterface = parameters.get("_llm_interface");
if (llmInterface.isBlank()) {
if (StringUtils.isBlank(llmInterface)) {
return false;
}

llmInterface = llmInterface.trim().toLowerCase(Locale.ROOT);
llmInterface = StringEscapeUtils.unescapeJava(llmInterface);
switch (llmInterface) {
case "openai/v1/chat/completions":
case LLM_INTERFACE_OPENAI_V1_CHAT_COMPLETIONS:
return true;
default:
return false;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.common.transport.execute;

import org.opensearch.action.ActionType;

public class MLExecuteStreamTaskAction extends ActionType<MLExecuteTaskResponse> {
public static final MLExecuteStreamTaskAction INSTANCE = new MLExecuteStreamTaskAction();
public static final String NAME = "cluster:admin/opensearch/ml/execute/stream";

private MLExecuteStreamTaskAction() {
super(NAME, MLExecuteTaskResponse::new);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,19 +22,27 @@
import org.opensearch.ml.common.MLCommonsClassLoader;
import org.opensearch.ml.common.input.Input;
import org.opensearch.ml.common.transport.MLTaskRequest;
import org.opensearch.transport.TransportChannel;

import lombok.AccessLevel;
import lombok.Builder;
import lombok.Getter;
import lombok.NonNull;
import lombok.Setter;
import lombok.ToString;
import lombok.experimental.FieldDefaults;
import lombok.experimental.NonFinal;

@Getter
@FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE)
@ToString
public class MLExecuteTaskRequest extends MLTaskRequest {

@Getter
@Setter
@NonFinal
private transient TransportChannel streamingChannel;

FunctionName functionName;
Input input;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,24 @@
import org.opensearch.ml.common.exception.ExecuteException;
import org.opensearch.ml.common.input.Input;
import org.opensearch.ml.common.output.Output;
import org.opensearch.transport.TransportChannel;

public interface Executable {

/**
* Execute algorithm with given input data.
* Execute algorithm with given input data (non-streaming).
* @param input input data
* @param listener action listener
*/
void execute(Input input, ActionListener<Output> listener) throws ExecuteException;
default void execute(Input input, ActionListener<Output> listener) throws ExecuteException {
execute(input, listener, null);
}

/**
* Execute algorithm with given input data (streaming).
* @param input input data
* @param listener action listener
* @param channel transport channel
*/
default void execute(Input input, ActionListener<Output> listener, TransportChannel channel) {}
}
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import org.opensearch.ml.common.output.MLOutput;
import org.opensearch.ml.common.output.Output;
import org.opensearch.ml.engine.encryptor.Encryptor;
import org.opensearch.transport.TransportChannel;

import lombok.Getter;
import lombok.extern.log4j.Log4j2;
Expand Down Expand Up @@ -186,7 +187,7 @@ public MLOutput trainAndPredict(Input input) {
return trainAndPredictable.trainAndPredict(mlInput);
}

public void execute(Input input, ActionListener<Output> listener) throws Exception {
public void execute(Input input, ActionListener<Output> listener, TransportChannel channel) throws Exception {
validateInput(input);
if (input.getFunctionName() == FunctionName.METRICS_CORRELATION) {
MLExecutable executable = MLEngineClassLoader.initInstance(input.getFunctionName(), input, Input.class);
Expand All @@ -199,6 +200,10 @@ public void execute(Input input, ActionListener<Output> listener) throws Excepti
if (executable == null) {
throw new IllegalArgumentException("Unsupported executable function: " + input.getFunctionName());
}
if (channel != null) {
executable.execute(input, listener, channel);
return;
}
executable.execute(input, listener);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@
import org.opensearch.remote.metadata.client.SdkClient;
import org.opensearch.remote.metadata.common.SdkClientUtils;
import org.opensearch.search.fetch.subphase.FetchSourceContext;
import org.opensearch.transport.TransportChannel;
import org.opensearch.transport.client.Client;

import com.google.common.annotations.VisibleForTesting;
Expand Down Expand Up @@ -143,7 +144,7 @@ public void onMultiTenancyEnabledChanged(boolean isEnabled) {
}

@Override
public void execute(Input input, ActionListener<Output> listener) {
public void execute(Input input, ActionListener<Output> listener, TransportChannel channel) {
if (!(input instanceof AgentMLInput)) {
throw new IllegalArgumentException("wrong input");
}
Expand Down Expand Up @@ -271,7 +272,8 @@ public void execute(Input input, ActionListener<Output> listener) {
isAsync,
outputs,
modelTensors,
mlAgent
mlAgent,
channel
);
}, e -> {
log.error("Failed to get existing interaction for regeneration", e);
Expand All @@ -287,7 +289,8 @@ public void execute(Input input, ActionListener<Output> listener) {
isAsync,
outputs,
modelTensors,
mlAgent
mlAgent,
channel
);
}
}, ex -> {
Expand Down Expand Up @@ -318,7 +321,8 @@ public void execute(Input input, ActionListener<Output> listener) {
outputs,
modelTensors,
listener,
createdMemory
createdMemory,
channel
),
ex -> {
log.error("Failed to find memory with memory_id: {}", memoryId, ex);
Expand All @@ -329,7 +333,6 @@ public void execute(Input input, ActionListener<Output> listener) {
return;
}
}

executeAgent(
inputDataSet,
mlTask,
Expand All @@ -339,7 +342,8 @@ public void execute(Input input, ActionListener<Output> listener) {
outputs,
modelTensors,
listener,
null
null,
channel
);
}
} catch (Exception e) {
Expand Down Expand Up @@ -382,7 +386,8 @@ private void saveRootInteractionAndExecute(
boolean isAsync,
List<ModelTensors> outputs,
List<ModelTensor> modelTensors,
MLAgent mlAgent
MLAgent mlAgent,
TransportChannel channel
) {
String appType = mlAgent.getAppType();
String question = inputDataSet.getParameters().get(QUESTION);
Expand Down Expand Up @@ -416,7 +421,8 @@ private void saveRootInteractionAndExecute(
outputs,
modelTensors,
listener,
memory
memory,
channel
),
e -> {
log.error("Failed to regenerate for interaction {}", regenerateInteractionId, e);
Expand All @@ -425,7 +431,18 @@ private void saveRootInteractionAndExecute(
)
);
} else {
executeAgent(inputDataSet, mlTask, isAsync, memory.getConversationId(), mlAgent, outputs, modelTensors, listener, memory);
executeAgent(
inputDataSet,
mlTask,
isAsync,
memory.getConversationId(),
mlAgent,
outputs,
modelTensors,
listener,
memory,
channel
);
}
}, ex -> {
log.error("Failed to create parent interaction", ex);
Expand All @@ -442,7 +459,8 @@ private void executeAgent(
List<ModelTensors> outputs,
List<ModelTensor> modelTensors,
ActionListener<Output> listener,
ConversationIndexMemory memory
ConversationIndexMemory memory,
TransportChannel channel
) {
String mcpConnectorConfigJSON = (mlAgent.getParameters() != null) ? mlAgent.getParameters().get(MCP_CONNECTORS_FIELD) : null;
if (mcpConnectorConfigJSON != null && !mlFeatureEnabledSetting.isMcpConnectorEnabled()) {
Expand Down Expand Up @@ -494,7 +512,7 @@ private void executeAgent(
memory
);
inputDataSet.getParameters().put(TASK_ID_FIELD, taskId);
mlAgentRunner.run(mlAgent, inputDataSet.getParameters(), agentActionListener);
mlAgentRunner.run(mlAgent, inputDataSet.getParameters(), agentActionListener, channel);
}, e -> {
log.error("Failed to create task for agent async execution", e);
listener.onFailure(e);
Expand All @@ -508,7 +526,7 @@ private void executeAgent(
parentInteractionId,
memory
);
mlAgentRunner.run(mlAgent, inputDataSet.getParameters(), agentActionListener);
mlAgentRunner.run(mlAgent, inputDataSet.getParameters(), agentActionListener, channel);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,17 +9,29 @@

import org.opensearch.core.action.ActionListener;
import org.opensearch.ml.common.agent.MLAgent;
import org.opensearch.transport.TransportChannel;

/**
* Agent executor interface definition. Agent executor will be used by {@link MLAgentExecutor} to invoke agents.
*/
public interface MLAgentRunner {

/**
* Function interface to execute agent.
* Function interface to execute agent (non-streaming)
* @param mlAgent
* @param params
* @param listener
*/
void run(MLAgent mlAgent, Map<String, String> params, ActionListener<Object> listener);
default void run(MLAgent mlAgent, Map<String, String> params, ActionListener<Object> listener) {
run(mlAgent, params, listener, null);
}

/**
* Function interface to execute agent (streaming)
* @param mlAgent
* @param params
* @param listener
* @param channel
*/
void run(MLAgent mlAgent, Map<String, String> params, ActionListener<Object> listener, TransportChannel channel);
}
Loading
Loading