-
Couldn't load subscription status.
- Fork 184
FEATURE: Summarize the steps when max steps limit reached #4184
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
af5ea63
51dbde8
45fb807
86e65a1
24c90ec
f20ae66
6b57c8e
e0aa50f
eeeea3f
865d8ae
bdbb695
bff9bb5
750904e
2ddad4c
8de0f73
9292f3e
c7f646a
8ca2b93
21677cc
4234aac
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -34,6 +34,7 @@ | |
| import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.parseLLMOutput; | ||
| import static org.opensearch.ml.engine.algorithms.agent.AgentUtils.substitute; | ||
| import static org.opensearch.ml.engine.algorithms.agent.PromptTemplate.CHAT_HISTORY_PREFIX; | ||
| import static org.opensearch.ml.engine.algorithms.agent.PromptTemplate.SUMMARY_PROMPT_TEMPLATE; | ||
| import static org.opensearch.ml.engine.tools.ReadFromScratchPadTool.SCRATCHPAD_NOTES_KEY; | ||
|
|
||
| import java.security.PrivilegedActionException; | ||
|
|
@@ -57,17 +58,22 @@ | |
| import org.opensearch.core.action.ActionListener; | ||
| import org.opensearch.core.common.Strings; | ||
| import org.opensearch.core.xcontent.NamedXContentRegistry; | ||
| import org.opensearch.ml.common.FunctionName; | ||
| import org.opensearch.ml.common.agent.LLMSpec; | ||
| import org.opensearch.ml.common.agent.MLAgent; | ||
| import org.opensearch.ml.common.agent.MLToolSpec; | ||
| import org.opensearch.ml.common.conversation.Interaction; | ||
| import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; | ||
| import org.opensearch.ml.common.input.remote.RemoteInferenceMLInput; | ||
| import org.opensearch.ml.common.output.model.ModelTensor; | ||
| import org.opensearch.ml.common.output.model.ModelTensorOutput; | ||
| import org.opensearch.ml.common.output.model.ModelTensors; | ||
| import org.opensearch.ml.common.spi.memory.Memory; | ||
| import org.opensearch.ml.common.spi.memory.Message; | ||
| import org.opensearch.ml.common.spi.tools.Tool; | ||
| import org.opensearch.ml.common.transport.MLTaskResponse; | ||
| import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction; | ||
| import org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest; | ||
| import org.opensearch.ml.common.utils.StringUtils; | ||
| import org.opensearch.ml.engine.encryptor.Encryptor; | ||
| import org.opensearch.ml.engine.function_calling.FunctionCalling; | ||
|
|
@@ -125,6 +131,8 @@ public class MLChatAgentRunner implements MLAgentRunner { | |
|
|
||
| private static final String DEFAULT_MAX_ITERATIONS = "10"; /* NOTE(review): presumably the fallback iteration cap when none is configured; its use is not visible here - confirm */ | ||
| private static final String MAX_ITERATIONS_MESSAGE = "Agent reached maximum iterations (%d) without completing the task"; /* format string: %d receives the iteration limit */ | ||
| private static final String MAX_ITERATIONS_SUMMARY_MESSAGE = MAX_ITERATIONS_MESSAGE /* same message, extended below with a summary of the steps */ | ||
| + ". Here's a summary of the steps completed so far:\n\n%s"; /* %s receives the summary text */ | ||
|
|
||
| private Client client; | ||
| private Settings settings; | ||
|
|
@@ -321,7 +329,6 @@ private void runReAct( | |
|
|
||
| StringBuilder scratchpadBuilder = new StringBuilder(); | ||
| List<String> interactions = new CopyOnWriteArrayList<>(); | ||
|
|
||
| StringSubstitutor tmpSubstitutor = new StringSubstitutor(Map.of(SCRATCHPAD, scratchpadBuilder.toString()), "${parameters.", "}"); | ||
| AtomicReference<String> newPrompt = new AtomicReference<>(tmpSubstitutor.replace(prompt)); | ||
| tmpParameters.put(PROMPT, newPrompt.get()); | ||
|
|
@@ -413,7 +420,10 @@ private void runReAct( | |
| additionalInfo, | ||
| lastThought, | ||
| maxIterations, | ||
| tools | ||
| tools, | ||
| tmpParameters, | ||
| llm, | ||
| tenantId | ||
| ); | ||
| return; | ||
| } | ||
|
|
@@ -514,7 +524,10 @@ private void runReAct( | |
| additionalInfo, | ||
| lastThought, | ||
| maxIterations, | ||
| tools | ||
| tools, | ||
| tmpParameters, | ||
| llm, | ||
| tenantId | ||
| ); | ||
| return; | ||
| } | ||
|
|
@@ -887,11 +900,64 @@ private void handleMaxIterationsReached( | |
| Map<String, Object> additionalInfo, | ||
| AtomicReference<String> lastThought, | ||
| int maxIterations, | ||
| Map<String, Tool> tools, | ||
| Map<String, String> parameters, | ||
| LLMSpec llmSpec, | ||
| String tenantId | ||
| ) { | ||
| ActionListener<String> responseListener = ActionListener.wrap(response -> { | ||
| sendTraditionalMaxIterationsResponse( | ||
| sessionId, | ||
| listener, | ||
| question, | ||
| parentInteractionId, | ||
| verbose, | ||
| traceDisabled, | ||
| traceTensors, | ||
| conversationIndexMemory, | ||
| traceNumber, | ||
| additionalInfo, | ||
| response, | ||
| tools | ||
| ); | ||
| }, listener::onFailure); | ||
|
|
||
| generateLLMSummary( | ||
| traceTensors, | ||
| llmSpec, | ||
| tenantId, | ||
| ActionListener | ||
| .wrap( | ||
| summary -> responseListener | ||
| .onResponse(String.format(Locale.ROOT, MAX_ITERATIONS_SUMMARY_MESSAGE, maxIterations, summary)), | ||
| e -> { | ||
| log.error("Failed to generate LLM summary, using fallback strategy", e); | ||
| String fallbackResponse = (lastThought.get() != null | ||
| && !lastThought.get().isEmpty() | ||
| && !"null".equals(lastThought.get())) | ||
| ? String | ||
| .format("%s. Last thought: %s", String.format(MAX_ITERATIONS_MESSAGE, maxIterations), lastThought.get()) | ||
| : String.format(MAX_ITERATIONS_MESSAGE, maxIterations); | ||
| responseListener.onResponse(fallbackResponse); | ||
| } | ||
| ) | ||
| ); | ||
| } | ||
|
|
||
| private void sendTraditionalMaxIterationsResponse( | ||
| String sessionId, | ||
| ActionListener<Object> listener, | ||
| String question, | ||
| String parentInteractionId, | ||
| boolean verbose, | ||
| boolean traceDisabled, | ||
| List<ModelTensors> traceTensors, | ||
| ConversationIndexMemory conversationIndexMemory, | ||
| AtomicInteger traceNumber, | ||
| Map<String, Object> additionalInfo, | ||
| String response, | ||
| Map<String, Tool> tools | ||
| ) { | ||
| String incompleteResponse = (lastThought.get() != null && !lastThought.get().isEmpty() && !"null".equals(lastThought.get())) | ||
| ? String.format("%s. Last thought: %s", String.format(MAX_ITERATIONS_MESSAGE, maxIterations), lastThought.get()) | ||
| : String.format(MAX_ITERATIONS_MESSAGE, maxIterations); | ||
| sendFinalAnswer( | ||
| sessionId, | ||
| listener, | ||
|
|
@@ -903,11 +969,76 @@ private void handleMaxIterationsReached( | |
| conversationIndexMemory, | ||
| traceNumber, | ||
| additionalInfo, | ||
| incompleteResponse | ||
| response | ||
| ); | ||
| cleanUpResource(tools); | ||
| } | ||
|
|
||
| /** | ||
|  * Asks the LLM described by {@code llmSpec} to summarize the recorded agent steps and reports the | ||
|  * outcome to {@code listener}. | ||
|  * | ||
|  * <p>Step text is taken from each tensor's result, or from the {@code "response"} entry of its data | ||
|  * map when the result is null. The listener is failed, rather than the LLM being called, when there | ||
|  * is nothing to summarize; it is also failed when the request cannot be built, when the prediction | ||
|  * call fails, or when the model returns no usable text. | ||
|  * | ||
|  * @param stepsSummary trace tensors recorded for the steps executed so far | ||
|  * @param llmSpec LLM specification providing the model id and optional request parameters | ||
|  * @param tenantId tenant id forwarded to the prediction request | ||
|  * @param listener receives the summary text on success, or the failure cause | ||
|  */ | ||
| void generateLLMSummary(List<ModelTensors> stepsSummary, LLMSpec llmSpec, String tenantId, ActionListener<String> listener) { | ||
|     if (stepsSummary == null || stepsSummary.isEmpty()) { | ||
|         listener.onFailure(new IllegalArgumentException("Steps summary cannot be null or empty")); | ||
|         return; | ||
|     } | ||
|
|
||
|     try { | ||
|         Map<String, String> summaryParams = new HashMap<>(); | ||
|         if (llmSpec.getParameters() != null) { | ||
|             summaryParams.putAll(llmSpec.getParameters()); | ||
|         } | ||
|
|
||
|         // Convert ModelTensors to strings before joining | ||
|         List<String> stepStrings = new ArrayList<>(); | ||
|         for (ModelTensors tensor : stepsSummary) { | ||
|             if (tensor != null && tensor.getMlModelTensors() != null) { | ||
|                 for (ModelTensor modelTensor : tensor.getMlModelTensors()) { | ||
|                     if (modelTensor.getResult() != null) { | ||
|                         stepStrings.add(modelTensor.getResult()); | ||
|                     } else if (modelTensor.getDataAsMap() != null && modelTensor.getDataAsMap().containsKey("response")) { | ||
|                         stepStrings.add(String.valueOf(modelTensor.getDataAsMap().get("response"))); | ||
|                     } | ||
|                 } | ||
|             } | ||
|         } | ||
|         // Summarizing an empty step list would only invite a made-up answer; fail so the caller can fall back. | ||
|         if (stepStrings.isEmpty()) { | ||
|             listener.onFailure(new IllegalArgumentException("Steps summary contains no step output to summarize")); | ||
|             return; | ||
|         } | ||
|         String summaryPrompt = String.format(Locale.ROOT, SUMMARY_PROMPT_TEMPLATE, String.join("\n", stepStrings)); | ||
|         summaryParams.put(PROMPT, summaryPrompt); | ||
|         // Use a non-empty default: remote model providers may reject an empty system prompt. | ||
|         summaryParams.putIfAbsent(SYSTEM_PROMPT_FIELD, "You are a helpful assistant."); | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why is this being added as empty? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @Hailong-am This change was made based on hailong’s previous comment. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. can you test this? remote model providers generally throw exception on empty system prompt lets add a simple default prompt instead of empty prompt? |
||
|
|
||
|         ActionRequest request = new MLPredictionTaskRequest( | ||
|             llmSpec.getModelId(), | ||
|             RemoteInferenceMLInput | ||
|                 .builder() | ||
|                 .algorithm(FunctionName.REMOTE) | ||
|                 .inputDataset(RemoteInferenceInputDataSet.builder().parameters(summaryParams).build()) | ||
|                 .build(), | ||
|             null, | ||
|             tenantId | ||
|         ); | ||
|         client.execute(MLPredictionTaskAction.INSTANCE, request, ActionListener.wrap(response -> { | ||
|             String summary = extractSummaryFromResponse(response); | ||
|             if (summary == null) { | ||
|                 listener.onFailure(new RuntimeException("Empty or invalid LLM summary response")); | ||
|                 return; | ||
|             } | ||
|             listener.onResponse(summary); | ||
|         }, listener::onFailure)); | ||
|     } catch (Exception e) { | ||
|         // Covers a null llmSpec and any failure while building or dispatching the request. | ||
|         listener.onFailure(e); | ||
|     } | ||
| } | ||
|
|
||
| public String extractSummaryFromResponse(MLTaskResponse response) { | ||
| try { | ||
| String outputString = outputToOutputString(response.getOutput()); | ||
| if (outputString != null && !outputString.trim().isEmpty()) { | ||
| return outputString.trim(); | ||
| } | ||
| return null; | ||
| } catch (Exception e) { | ||
| log.error("Failed to extract summary from response", e); | ||
| throw new RuntimeException("Failed to extract summary from response", e); | ||
| } | ||
| } | ||
|
|
||
| private void saveMessage( | ||
| ConversationIndexMemory memory, | ||
| String question, | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.