diff --git a/client/src/test/java/org/opensearch/ml/client/MachineLearningClientTest.java b/client/src/test/java/org/opensearch/ml/client/MachineLearningClientTest.java index 9af058c734..52f5e48486 100644 --- a/client/src/test/java/org/opensearch/ml/client/MachineLearningClientTest.java +++ b/client/src/test/java/org/opensearch/ml/client/MachineLearningClientTest.java @@ -466,7 +466,7 @@ public void createConnector() { .name("test") .description("description") .version("testModelVersion") - .protocol("testProtocol") + .protocol("http") .parameters(params) .credential(credentials) .actions(null) diff --git a/client/src/test/java/org/opensearch/ml/client/MachineLearningNodeClientTest.java b/client/src/test/java/org/opensearch/ml/client/MachineLearningNodeClientTest.java index 1aa00c95d2..6e777c4cd8 100644 --- a/client/src/test/java/org/opensearch/ml/client/MachineLearningNodeClientTest.java +++ b/client/src/test/java/org/opensearch/ml/client/MachineLearningNodeClientTest.java @@ -1033,7 +1033,7 @@ public void createConnector() { .name("test") .description("description") .version("testModelVersion") - .protocol("testProtocol") + .protocol("http") .parameters(params) .credential(credentials) .actions(null) diff --git a/common/src/main/java/org/opensearch/ml/common/agent/MLAgent.java b/common/src/main/java/org/opensearch/ml/common/agent/MLAgent.java index b66a23f11e..d1ef79f52a 100644 --- a/common/src/main/java/org/opensearch/ml/common/agent/MLAgent.java +++ b/common/src/main/java/org/opensearch/ml/common/agent/MLAgent.java @@ -13,6 +13,7 @@ import java.io.IOException; import java.time.Instant; import java.util.ArrayList; +import java.util.Arrays; import java.util.HashSet; import java.util.List; import java.util.Locale; @@ -122,7 +123,7 @@ private void validate() { for (MLToolSpec toolSpec : tools) { String toolName = Optional.ofNullable(toolSpec.getName()).orElse(toolSpec.getType()); if (toolNames.contains(toolName)) { - throw new IllegalArgumentException("Duplicate tool defined: " + toolName); + throw new IllegalArgumentException("Duplicate tool defined in agent configuration"); } else { toolNames.add(toolName); } @@ -138,7 +139,7 @@ private void validateMLAgentType(String agentType) { MLAgentType.valueOf(agentType.toUpperCase(Locale.ROOT)); // Use toUpperCase() to allow case-insensitive matching } catch (IllegalArgumentException e) { // The typeStr does not match any MLAgentType, so throw a new exception with a clearer message. - throw new IllegalArgumentException(agentType + " is not a valid Agent Type"); + throw new IllegalArgumentException("Invalid Agent Type, Please use one of " + Arrays.toString(MLAgentType.values())); } } } diff --git a/common/src/main/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorInput.java b/common/src/main/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorInput.java index f0d8cd656f..d138093ba0 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorInput.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorInput.java @@ -31,6 +31,7 @@ import org.opensearch.ml.common.CommonValue; import org.opensearch.ml.common.connector.ConnectorAction; import org.opensearch.ml.common.connector.ConnectorClientConfig; +import org.opensearch.ml.common.connector.ConnectorProtocols; import lombok.Builder; import lombok.Data; @@ -121,6 +122,7 @@ public MLCreateConnectorInput( if ((url == null || url.isBlank()) && isMcpConnector) { throw new IllegalArgumentException("MCP Connector url is null or blank"); } + ConnectorProtocols.validateProtocol(protocol); } this.name = name; this.description = description; diff --git a/common/src/test/java/org/opensearch/ml/common/agent/MLAgentTest.java b/common/src/test/java/org/opensearch/ml/common/agent/MLAgentTest.java index da2f5f5c1e..d272877597 100644 --- a/common/src/test/java/org/opensearch/ml/common/agent/MLAgentTest.java +++ b/common/src/test/java/org/opensearch/ml/common/agent/MLAgentTest.java @@ -9,6 +9,7 @@ import java.io.IOException; import java.time.Instant; +import java.util.Arrays; import java.util.Collections; import java.util.List; import java.util.Map; @@ -114,7 +115,7 @@ public void constructor_NullLLMSpec() { @Test public void constructor_DuplicateTool() { exceptionRule.expect(IllegalArgumentException.class); - exceptionRule.expectMessage("Duplicate tool defined: test"); + exceptionRule.expectMessage("Duplicate tool defined in agent configuration"); MLAgent agent = new MLAgent( "test_name", @@ -353,7 +354,7 @@ public void fromStream() throws IOException { @Test public void constructor_InvalidAgentType() { exceptionRule.expect(IllegalArgumentException.class); - exceptionRule.expectMessage(" is not a valid Agent Type"); + exceptionRule.expectMessage("Invalid Agent Type, Please use one of " + Arrays.toString(MLAgentType.values())); new MLAgent( "test_name", diff --git a/common/src/test/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorInputTests.java b/common/src/test/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorInputTests.java index a7df00618a..bac44679fc 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorInputTests.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/connector/MLCreateConnectorInputTests.java @@ -10,8 +10,7 @@ import static org.junit.Assert.assertNull; import static org.junit.Assert.assertThrows; import static org.junit.Assert.assertTrue; -import static org.opensearch.ml.common.connector.ConnectorProtocols.MCP_SSE; -import static org.opensearch.ml.common.connector.ConnectorProtocols.MCP_STREAMABLE_HTTP; +import static org.opensearch.ml.common.connector.ConnectorProtocols.*; import java.io.IOException; import java.util.Arrays; @@ -169,6 +168,29 @@ public void constructorMLCreateConnectorInput_NullProtocol() { assertEquals("Connector protocol is null", exception.getMessage()); } + @Test + public void constructorMLCreateConnectorInput_InvalidProtocol() { + Throwable exception = assertThrows(IllegalArgumentException.class, () -> { + MLCreateConnectorInput + .builder() + .name(TEST_CONNECTOR_NAME) + .description(TEST_CONNECTOR_DESCRIPTION) + .version(TEST_CONNECTOR_VERSION) + .protocol("dummy") + .parameters(Map.of(TEST_PARAM_KEY, TEST_PARAM_VALUE)) + .credential(Map.of(TEST_CREDENTIAL_KEY, TEST_CREDENTIAL_VALUE)) + .actions(List.of()) + .access(AccessMode.PUBLIC) + .backendRoles(Arrays.asList(TEST_ROLE1, TEST_ROLE2)) + .addAllBackendRoles(false) + .build(); + }); + assertEquals( + "Unsupported connector protocol. Please use one of " + Arrays.toString(VALID_PROTOCOLS.toArray(new String[0])), + exception.getMessage() + ); + } + @Test public void constructorMLCreateConnectorInput_NullCredential() { Throwable exception = assertThrows(IllegalArgumentException.class, () -> { diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/AgentUtils.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/AgentUtils.java index 7de547127a..42af6dad29 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/AgentUtils.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/AgentUtils.java @@ -993,7 +993,7 @@ public static List getToolNames(Map tools) { public static Tool createTool(Map toolFactories, Map executeParams, MLToolSpec toolSpec) { if (!toolFactories.containsKey(toolSpec.getType())) { - throw new IllegalArgumentException("Tool not found: " + toolSpec.getType()); + throw new IllegalArgumentException("Tool type not found"); } Map toolParams = new HashMap<>(); toolParams.putAll(executeParams); diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLAgentExecutor.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLAgentExecutor.java index 1594506cf4..bd8c5145eb 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLAgentExecutor.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/agent/MLAgentExecutor.java @@ -21,12 +21,7 @@ import java.security.PrivilegedActionException; import java.security.PrivilegedExceptionAction; import java.time.Instant; -import java.util.ArrayList; -import java.util.HashMap; -import java.util.List; -import java.util.Locale; -import java.util.Map; -import java.util.Objects; +import java.util.*; import org.opensearch.ExceptionsHelper; import org.opensearch.OpenSearchException; @@ -654,7 +649,7 @@ protected MLAgentRunner getAgentRunner(MLAgent mlAgent) { encryptor ); default: - throw new IllegalArgumentException("Unsupported agent type: " + mlAgent.getType()); + throw new IllegalArgumentException("Unsupported agent type. Please use one of " + Arrays.toString(MLAgentType.values())); } } diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/function_calling/FunctionCallingFactory.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/function_calling/FunctionCallingFactory.java index 5a38d0b7bd..66fac6a609 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/function_calling/FunctionCallingFactory.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/function_calling/FunctionCallingFactory.java @@ -27,7 +27,15 @@ public static FunctionCalling create(String llmInterface) { case LLM_INTERFACE_BEDROCK_CONVERSE_DEEPSEEK_R1: return new BedrockConverseDeepseekR1FunctionCalling(); default: - throw new IllegalArgumentException(String.format("Invalid _llm_interface: %s", llmInterface)); + throw new IllegalArgumentException( + String + .format( + "Invalid _llm_interface. Supported values are %s,%s,%s", + LLM_INTERFACE_BEDROCK_CONVERSE_CLAUDE, + LLM_INTERFACE_OPENAI_V1_CHAT_COMPLETIONS, + LLM_INTERFACE_BEDROCK_CONVERSE_DEEPSEEK_R1 + ) + ); } } } diff --git a/plugin/src/main/java/org/opensearch/ml/action/agents/TransportRegisterAgentAction.java b/plugin/src/main/java/org/opensearch/ml/action/agents/TransportRegisterAgentAction.java index e91d27c9bb..0680cc0bfd 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/agents/TransportRegisterAgentAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/agents/TransportRegisterAgentAction.java @@ -104,7 +104,7 @@ private void registerAgent(MLAgent agent, ActionListener !buildInToolNames.contains(type)) .collect(Collectors.toSet()); if (!unrecognizedTools.isEmpty()) { - exception.addValidationError(String.format(Locale.ROOT, "Unrecognized tool in request: %s", unrecognizedTools)); + exception.addValidationError("Unrecognized tool in request"); throw exception; } return channel -> client.execute(MLMcpToolsRegisterAction.INSTANCE, registerNodesRequest, new RestToXContentListener<>(channel));