Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ buildscript {
ext {
opensearch_group = "org.opensearch"
isSnapshot = "true" == System.getProperty("build.snapshot", "true")
opensearch_version = System.getProperty("opensearch.version", "3.4.0-SNAPSHOT")
opensearch_version = System.getProperty("opensearch.version", "3.3.0-SNAPSHOT")
buildVersionQualifier = System.getProperty("build.version_qualifier", "")
asm_version = "9.7"

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.parseLLMOutput;
import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.substitute;
import static org.opensearch.ml.engine.algorithms.agent.PromptTemplate.CHAT_HISTORY_PREFIX;
import static org.opensearch.ml.engine.algorithms.agent.PromptTemplate.SUMMARY_PROMPT_TEMPLATE;
import static org.opensearch.ml.engine.tools.ReadFromScratchPadTool.SCRATCHPAD_NOTES_KEY;

import java.security.PrivilegedActionException;
Expand All @@ -57,17 +58,22 @@
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.common.Strings;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.agent.LLMSpec;
import org.opensearch.ml.common.agent.MLAgent;
import org.opensearch.ml.common.agent.MLToolSpec;
import org.opensearch.ml.common.conversation.Interaction;
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
import org.opensearch.ml.common.input.remote.RemoteInferenceMLInput;
import org.opensearch.ml.common.output.model.ModelTensor;
import org.opensearch.ml.common.output.model.ModelTensorOutput;
import org.opensearch.ml.common.output.model.ModelTensors;
import org.opensearch.ml.common.spi.memory.Memory;
import org.opensearch.ml.common.spi.memory.Message;
import org.opensearch.ml.common.spi.tools.Tool;
import org.opensearch.ml.common.transport.MLTaskResponse;
import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction;
import org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest;
import org.opensearch.ml.common.utils.StringUtils;
import org.opensearch.ml.engine.encryptor.Encryptor;
import org.opensearch.ml.engine.function_calling.FunctionCalling;
Expand Down Expand Up @@ -125,6 +131,8 @@ public class MLChatAgentRunner implements MLAgentRunner {

// Default cap on ReAct loop iterations when "max_iteration" is not supplied by the caller.
private static final String DEFAULT_MAX_ITERATIONS = "10";
// Bare termination message; %d is the configured iteration limit.
private static final String MAX_ITERATIONS_MESSAGE = "Agent reached maximum iterations (%d) without completing the task";
// Termination message extended with an LLM-generated summary (%s) of the steps completed so far.
private static final String MAX_ITERATIONS_SUMMARY_MESSAGE = MAX_ITERATIONS_MESSAGE
    + ". Here's a summary of the steps completed so far:\n\n%s";

private Client client;
private Settings settings;
Expand Down Expand Up @@ -321,7 +329,6 @@ private void runReAct(

StringBuilder scratchpadBuilder = new StringBuilder();
List<String> interactions = new CopyOnWriteArrayList<>();

StringSubstitutor tmpSubstitutor = new StringSubstitutor(Map.of(SCRATCHPAD, scratchpadBuilder.toString()), "${parameters.", "}");
AtomicReference<String> newPrompt = new AtomicReference<>(tmpSubstitutor.replace(prompt));
tmpParameters.put(PROMPT, newPrompt.get());
Expand Down Expand Up @@ -413,7 +420,10 @@ private void runReAct(
additionalInfo,
lastThought,
maxIterations,
tools
tools,
tmpParameters,
llm,
tenantId
);
return;
}
Expand Down Expand Up @@ -514,7 +524,10 @@ private void runReAct(
additionalInfo,
lastThought,
maxIterations,
tools
tools,
tmpParameters,
llm,
tenantId
);
return;
}
Expand Down Expand Up @@ -887,11 +900,64 @@ private void handleMaxIterationsReached(
Map<String, Object> additionalInfo,
AtomicReference<String> lastThought,
int maxIterations,
Map<String, Tool> tools,
Map<String, String> parameters,
LLMSpec llmSpec,
String tenantId
) {
ActionListener<String> responseListener = ActionListener.wrap(response -> {
sendTraditionalMaxIterationsResponse(
sessionId,
listener,
question,
parentInteractionId,
verbose,
traceDisabled,
traceTensors,
conversationIndexMemory,
traceNumber,
additionalInfo,
response,
tools
);
}, listener::onFailure);

generateLLMSummary(
traceTensors,
llmSpec,
tenantId,
ActionListener
.wrap(
summary -> responseListener
.onResponse(String.format(Locale.ROOT, MAX_ITERATIONS_SUMMARY_MESSAGE, maxIterations, summary)),
e -> {
log.error("Failed to generate LLM summary, using fallback strategy", e);
String fallbackResponse = (lastThought.get() != null
&& !lastThought.get().isEmpty()
&& !"null".equals(lastThought.get()))
? String
.format("%s. Last thought: %s", String.format(MAX_ITERATIONS_MESSAGE, maxIterations), lastThought.get())
: String.format(MAX_ITERATIONS_MESSAGE, maxIterations);
responseListener.onResponse(fallbackResponse);
}
)
);
}

private void sendTraditionalMaxIterationsResponse(
String sessionId,
ActionListener<Object> listener,
String question,
String parentInteractionId,
boolean verbose,
boolean traceDisabled,
List<ModelTensors> traceTensors,
ConversationIndexMemory conversationIndexMemory,
AtomicInteger traceNumber,
Map<String, Object> additionalInfo,
String response,
Map<String, Tool> tools
) {
String incompleteResponse = (lastThought.get() != null && !lastThought.get().isEmpty() && !"null".equals(lastThought.get()))
? String.format("%s. Last thought: %s", String.format(MAX_ITERATIONS_MESSAGE, maxIterations), lastThought.get())
: String.format(MAX_ITERATIONS_MESSAGE, maxIterations);
sendFinalAnswer(
sessionId,
listener,
Expand All @@ -903,11 +969,76 @@ private void handleMaxIterationsReached(
conversationIndexMemory,
traceNumber,
additionalInfo,
incompleteResponse
response
);
cleanUpResource(tools);
}

void generateLLMSummary(List<ModelTensors> stepsSummary, LLMSpec llmSpec, String tenantId, ActionListener<String> listener) {
if (stepsSummary == null || stepsSummary.isEmpty()) {
listener.onFailure(new IllegalArgumentException("Steps summary cannot be null or empty"));
return;
}

try {
Map<String, String> summaryParams = new HashMap<>();
if (llmSpec.getParameters() != null) {
summaryParams.putAll(llmSpec.getParameters());
}

// Convert ModelTensors to strings before joining
List<String> stepStrings = new ArrayList<>();
for (ModelTensors tensor : stepsSummary) {
if (tensor != null && tensor.getMlModelTensors() != null) {
for (ModelTensor modelTensor : tensor.getMlModelTensors()) {
if (modelTensor.getResult() != null) {
stepStrings.add(modelTensor.getResult());
} else if (modelTensor.getDataAsMap() != null && modelTensor.getDataAsMap().containsKey("response")) {
stepStrings.add(String.valueOf(modelTensor.getDataAsMap().get("response")));
}
}
}
}
String summaryPrompt = String.format(Locale.ROOT, SUMMARY_PROMPT_TEMPLATE, String.join("\n", stepStrings));
summaryParams.put(PROMPT, summaryPrompt);
summaryParams.putIfAbsent(SYSTEM_PROMPT_FIELD, "");
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is this being added as empty?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@Hailong-am This change was made based on hailong’s previous comment.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can you test this? remote model providers generally throw exception on empty system prompt

lets add a simple default prompt instead of empty prompt?


ActionRequest request = new MLPredictionTaskRequest(
llmSpec.getModelId(),
RemoteInferenceMLInput
.builder()
.algorithm(FunctionName.REMOTE)
.inputDataset(RemoteInferenceInputDataSet.builder().parameters(summaryParams).build())
.build(),
null,
tenantId
);
client.execute(MLPredictionTaskAction.INSTANCE, request, ActionListener.wrap(response -> {
String summary = extractSummaryFromResponse(response);
if (summary == null) {
listener.onFailure(new RuntimeException("Empty or invalid LLM summary response"));
return;
}
listener.onResponse(summary);
}, listener::onFailure));
} catch (Exception e) {
listener.onFailure(e);
}
}

/**
 * Pulls the summary text out of an LLM prediction response.
 *
 * @param response the prediction task response whose output is converted to a string
 * @return the trimmed summary text, or null when the model output is absent or blank
 * @throws RuntimeException (with cause) when converting the output fails
 */
public String extractSummaryFromResponse(MLTaskResponse response) {
    try {
        String text = outputToOutputString(response.getOutput());
        if (text == null) {
            return null;
        }
        String trimmed = text.trim();
        return trimmed.isEmpty() ? null : trimmed;
    } catch (Exception e) {
        log.error("Failed to extract summary from response", e);
        throw new RuntimeException("Failed to extract summary from response", e);
    }
}

private void saveMessage(
ConversationIndexMemory memory,
String question,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -140,4 +140,7 @@ public class PromptTemplate {
- Avoid making assumptions and relying on implicit knowledge.
- Your response must be self-contained and ready for the planner to use without modification. Never end with a question.
- Break complex searches into simpler queries when appropriate.""";

// Prompt used to ask the LLM for a summary of the agent's execution trace;
// %s is the newline-joined list of completed step outputs.
public static final String SUMMARY_PROMPT_TEMPLATE =
    "Please provide a concise summary of the following agent execution steps. Focus on what the agent was trying to accomplish and what progress was made:\n\n%s";
}
Original file line number Diff line number Diff line change
Expand Up @@ -1171,4 +1171,142 @@ public void testConstructLLMParams_DefaultValues() {
Assert.assertTrue(result.containsKey(AgentUtils.RESPONSE_FORMAT_INSTRUCTION));
Assert.assertTrue(result.containsKey(AgentUtils.TOOL_RESPONSE));
}

// Exercises the max-iterations termination path when LLM summarization succeeds:
// the agent runs a single ReAct iteration, hits the limit, and the final answer
// must carry the LLM-generated summary of the executed steps.
@Test
public void testMaxIterationsWithSummaryEnabled() {
    // Create LLM spec with max_iteration = 1 to simplify test
    LLMSpec llmSpec = LLMSpec.builder().modelId("MODEL_ID").parameters(Map.of("max_iteration", "1")).build();
    MLToolSpec firstToolSpec = MLToolSpec.builder().name(FIRST_TOOL).type(FIRST_TOOL).build();
    final MLAgent mlAgent = MLAgent
        .builder()
        .name("TestAgent")
        .type(MLAgentType.CONVERSATIONAL.name())
        .llm(llmSpec)
        .memory(mlMemorySpec)
        .tools(Arrays.asList(firstToolSpec))
        .build();

    // Reset and setup fresh mocks
    Mockito.reset(client);
    Mockito.reset(firstTool);
    when(firstTool.getName()).thenReturn(FIRST_TOOL);
    when(firstTool.validate(Mockito.anyMap())).thenReturn(true);
    Mockito.doAnswer(generateToolResponse("First tool response")).when(firstTool).run(Mockito.anyMap(), any());

    // First call: LLM response without final_answer to trigger max iterations
    // Second call: Summary LLM response with result field instead of dataAsMap
    Mockito
        .doAnswer(getLLMAnswer(ImmutableMap.of("thought", "I need to analyze the data", "action", FIRST_TOOL)))
        .doAnswer(invocation -> {
            // Stubbed summary call: answer with a ModelTensor whose `result` field carries the summary
            ActionListener<Object> listener = invocation.getArgument(2);
            ModelTensor modelTensor = ModelTensor.builder().result("Summary: Analysis step was attempted").build();
            ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor)).build();
            ModelTensorOutput mlModelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(modelTensors)).build();
            MLTaskResponse mlTaskResponse = MLTaskResponse.builder().output(mlModelTensorOutput).build();
            listener.onResponse(mlTaskResponse);
            return null;
        })
        .when(client)
        .execute(any(ActionType.class), any(ActionRequest.class), isA(ActionListener.class));

    Map<String, String> params = new HashMap<>();
    params.put(MLAgentExecutor.PARENT_INTERACTION_ID, "parent_interaction_id");
    mlChatAgentRunner.run(mlAgent, params, agentActionListener);

    // Verify response is captured
    verify(agentActionListener).onResponse(objectCaptor.capture());
    Object capturedResponse = objectCaptor.getValue();
    assertTrue(capturedResponse instanceof ModelTensorOutput);

    ModelTensorOutput modelTensorOutput = (ModelTensorOutput) capturedResponse;
    List<ModelTensor> agentOutput = modelTensorOutput.getMlModelOutputs().get(1).getMlModelTensors();
    assertEquals(1, agentOutput.size());

    // Verify the response contains summary message
    String response = (String) agentOutput.get(0).getDataAsMap().get("response");
    assertTrue(
        response
            .startsWith(
                "Agent reached maximum iterations (1) without completing the task. Here's a summary of the steps completed so far:"
            )
    );
    assertTrue(response.contains("Summary: Analysis step was attempted"));
}

// Exercises the max-iterations fallback path: when the summary LLM call fails,
// the final answer must fall back to the bare limit message plus the agent's last thought.
@Test
public void testMaxIterationsWithSummaryDisabled() {
    // Create LLM spec with max_iteration = 1
    LLMSpec llmSpec = LLMSpec.builder().modelId("MODEL_ID").parameters(Map.of("max_iteration", "1")).build();
    MLToolSpec firstToolSpec = MLToolSpec.builder().name(FIRST_TOOL).type(FIRST_TOOL).build();
    final MLAgent mlAgent = MLAgent
        .builder()
        .name("TestAgent")
        .type(MLAgentType.CONVERSATIONAL.name())
        .llm(llmSpec)
        .memory(mlMemorySpec)
        .tools(Arrays.asList(firstToolSpec))
        .build();

    // Reset client mock for this test
    Mockito.reset(client);
    // First call: LLM response without final_answer to trigger max iterations
    // Second call: Summary LLM fails
    Mockito.doAnswer(getLLMAnswer(ImmutableMap.of("thought", "I need to use the tool", "action", FIRST_TOOL))).doAnswer(invocation -> {
        ActionListener<Object> listener = invocation.getArgument(2);
        listener.onFailure(new RuntimeException("LLM summary generation failed"));
        return null;
    }).when(client).execute(any(ActionType.class), any(ActionRequest.class), isA(ActionListener.class));

    Map<String, String> params = new HashMap<>();
    params.put(MLAgentExecutor.PARENT_INTERACTION_ID, "parent_interaction_id");

    mlChatAgentRunner.run(mlAgent, params, agentActionListener);

    // Verify response is captured
    verify(agentActionListener).onResponse(objectCaptor.capture());
    Object capturedResponse = objectCaptor.getValue();
    assertTrue(capturedResponse instanceof ModelTensorOutput);

    ModelTensorOutput modelTensorOutput = (ModelTensorOutput) capturedResponse;
    List<ModelTensor> agentOutput = modelTensorOutput.getMlModelOutputs().get(1).getMlModelTensors();
    assertEquals(1, agentOutput.size());

    // Verify the response uses fallback strategy with last thought
    String response = (String) agentOutput.get(0).getDataAsMap().get("response");
    assertEquals("Agent reached maximum iterations (1) without completing the task. Last thought: I need to use the tool", response);
}

// A well-formed response whose tensor carries a `result` field should yield that text verbatim.
@Test
public void testExtractSummaryFromResponse() {
    ModelTensor summaryTensor = ModelTensor.builder().result("Valid summary text").build();
    ModelTensors tensorGroup = ModelTensors.builder().mlModelTensors(Arrays.asList(summaryTensor)).build();
    ModelTensorOutput output = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(tensorGroup)).build();
    MLTaskResponse response = MLTaskResponse.builder().output(output).build();

    assertEquals("Valid summary text", mlChatAgentRunner.extractSummaryFromResponse(response));
}

// A null step list must be rejected immediately with IllegalArgumentException,
// without issuing any LLM call.
@Test
public void testGenerateLLMSummaryWithNullSteps() {
    ActionListener<String> listener = Mockito.mock(ActionListener.class);
    LLMSpec llmSpec = LLMSpec.builder().modelId("MODEL_ID").build();

    mlChatAgentRunner.generateLLMSummary(null, llmSpec, "tenant", listener);

    verify(listener).onFailure(any(IllegalArgumentException.class));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -295,7 +295,7 @@ public void checkModelGroupPermission(MLModelGroup mlModelGroup, User user, Acti
*/
/**
 * Whether resource-level authorization should be used.
 *
 * The pasted diff left both the old and new return statements in place, producing
 * unreachable code; this keeps only the new behavior (availability of the client).
 *
 * @param resourceType resource type (currently unused; retained for interface
 *                     compatibility — the per-type feature-enabled check was removed)
 * @return true when a resource-sharing client is available
 */
public static boolean shouldUseResourceAuthz(String resourceType) {
    var client = ResourceSharingClientAccessor.getInstance().getResourceSharingClient();
    return client != null;
}

public boolean skipModelAccessControl(User user) {
Expand Down