Skip to content

Commit c4537fd

Browse files
xinyualbrianf-aws
authored andcommitted
Fix claude model it (opensearch-project#4167)
* fix model it by replace claude v1/v2 Signed-off-by: xinyual <[email protected]> * remove useless change Signed-off-by: xinyual <[email protected]> --------- Signed-off-by: xinyual <[email protected]> Signed-off-by: Brian Flores <[email protected]>
1 parent 14dc66f commit c4537fd

File tree

3 files changed

+140
-42
lines changed

3 files changed

+140
-42
lines changed

plugin/src/test/java/org/opensearch/ml/rest/RestConnectorToolIT.java

Lines changed: 24 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,9 @@ public class RestConnectorToolIT extends RestBaseAgentToolsIT {
1919
private static final String AWS_ACCESS_KEY_ID = System.getenv("AWS_ACCESS_KEY_ID");
2020
private static final String AWS_SECRET_ACCESS_KEY = System.getenv("AWS_SECRET_ACCESS_KEY");
2121
private static final String AWS_SESSION_TOKEN = System.getenv("AWS_SESSION_TOKEN");
22+
2223
private static final String GITHUB_CI_AWS_REGION = "us-west-2";
24+
private static final String BEDROCK_ANTHROPIC_CLAUDE_3_5_SONNET = "anthropic.claude-3-5-sonnet-20240620-v1:0";
2325

2426
private String bedrockClaudeConnectorId;
2527
private String bedrockClaudeConnectorIdForPredict;
@@ -35,19 +37,20 @@ public void setUp() throws Exception {
3537

3638
private String createBedrockClaudeConnector(String action) throws IOException, InterruptedException {
3739
String bedrockClaudeConnectorEntity = "{\n"
38-
+ " \"name\": \"BedRock Claude instant-v1 Connector \",\n"
39-
+ " \"description\": \"The connector to BedRock service for claude model\",\n"
40+
+ " \"name\": \"Bedrock Connector: claude 3.5\",\n"
41+
+ " \"description\": \"The connector to bedrock claude 3.5 model\",\n"
4042
+ " \"version\": 1,\n"
4143
+ " \"protocol\": \"aws_sigv4\",\n"
4244
+ " \"parameters\": {\n"
4345
+ " \"region\": \""
4446
+ GITHUB_CI_AWS_REGION
4547
+ "\",\n"
4648
+ " \"service_name\": \"bedrock\",\n"
47-
+ " \"anthropic_version\": \"bedrock-2023-05-31\",\n"
48-
+ " \"max_tokens_to_sample\": 8000,\n"
49-
+ " \"temperature\": 0.0001,\n"
50-
+ " \"response_filter\": \"$.completion\"\n"
49+
+ " \"model\": \""
50+
+ BEDROCK_ANTHROPIC_CLAUDE_3_5_SONNET
51+
+ "\",\n"
52+
+ " \"system_prompt\": \"You are a helpful assistant.\",\n"
53+
+ "\"response_filter\": \"$.output.message.content[0].text\""
5154
+ " },\n"
5255
+ " \"credential\": {\n"
5356
+ " \"access_key\": \""
@@ -61,19 +64,22 @@ private String createBedrockClaudeConnector(String action) throws IOException, I
6164
+ "\"\n"
6265
+ " },\n"
6366
+ " \"actions\": [\n"
64-
+ " {\n"
65-
+ " \"action_type\": \""
67+
+ " {\n"
68+
+ " \"action_type\": \""
6669
+ action
6770
+ "\",\n"
68-
+ " \"method\": \"POST\",\n"
69-
+ " \"url\": \"https://bedrock-runtime.${parameters.region}.amazonaws.com/model/anthropic.claude-instant-v1/invoke\",\n"
70-
+ " \"headers\": {\n"
71-
+ " \"content-type\": \"application/json\",\n"
72-
+ " \"x-amz-content-sha256\": \"required\"\n"
73-
+ " },\n"
74-
+ " \"request_body\": \"{\\\"prompt\\\":\\\"\\\\n\\\\nHuman:${parameters.question}\\\\n\\\\nAssistant:\\\", \\\"max_tokens_to_sample\\\":${parameters.max_tokens_to_sample}, \\\"temperature\\\":${parameters.temperature}, \\\"anthropic_version\\\":\\\"${parameters.anthropic_version}\\\" }\"\n"
75-
+ " }\n"
76-
+ " ]\n"
71+
+ " \"method\": \"POST\",\n"
72+
+ " \"headers\": {\n"
73+
+ " \"content-type\": \"application/json\"\n"
74+
+ " },\n"
75+
+ " \"url\": \"https://bedrock-runtime."
76+
+ GITHUB_CI_AWS_REGION
77+
+ ".amazonaws.com/model/"
78+
+ BEDROCK_ANTHROPIC_CLAUDE_3_5_SONNET
79+
+ "/converse\",\n"
80+
+ " \"request_body\": \"{ \\\"system\\\": [{\\\"text\\\": \\\"you are a helpful assistant.\\\"}], \\\"messages\\\":[{\\\"role\\\": \\\"user\\\", \\\"content\\\":[ {\\\"type\\\": \\\"text\\\", \\\"text\\\":\\\"${parameters.messages}\\\"}]}] , \\\"inferenceConfig\\\": {\\\"temperature\\\": 0.0, \\\"topP\\\": 0.9, \\\"maxTokens\\\": 1000} }\"\n"
81+
+ " }\n"
82+
+ " ]\n"
7783
+ "}";
7884
return registerConnector(bedrockClaudeConnectorEntity);
7985
}
@@ -135,7 +141,7 @@ public void testConnectorToolInFlowAgent() throws IOException {
135141
+ " ]\n"
136142
+ "}";
137143
String agentId = createAgent(registerAgentRequestBody);
138-
String agentInput = "{\n" + " \"parameters\": {\n" + " \"question\": \"hello\"\n" + " }\n" + "}";
144+
String agentInput = "{\n" + " \"parameters\": {\n" + " \"messages\": \"hello\"\n" + " }\n" + "}";
139145
String result = executeAgent(agentId, agentInput);
140146
assertNotNull(result);
141147
}

plugin/src/test/java/org/opensearch/ml/rest/RestMLInferenceSearchResponseProcessorIT.java

Lines changed: 24 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@ public class RestMLInferenceSearchResponseProcessorIT extends MLCommonsRestTestC
6767
private static final String AWS_ACCESS_KEY_ID = System.getenv("AWS_ACCESS_KEY_ID");
6868
private static final String AWS_SECRET_ACCESS_KEY = System.getenv("AWS_SECRET_ACCESS_KEY");
6969
private static final String AWS_SESSION_TOKEN = System.getenv("AWS_SESSION_TOKEN");
70+
7071
private static final String GITHUB_CI_AWS_REGION = "us-west-2";
7172

7273
private final String bedrockEmbeddingModelConnectorEntity = "{\n"
@@ -109,20 +110,20 @@ public class RestMLInferenceSearchResponseProcessorIT extends MLCommonsRestTestC
109110
+ "}";
110111

111112
private final String bedrockClaudeModelConnectorEntity = "{\n"
112-
+ " \"name\": \"BedRock Claude instant-v1 Connector\",\n"
113-
+ " \"description\": \"The connector to bedrock for claude model\",\n"
113+
+ " \"name\": \"Bedrock Connector: claude 3.5\",\n"
114+
+ " \"description\": \"The connector to bedrock claude 3.5 model\",\n"
114115
+ " \"version\": 1,\n"
115116
+ " \"protocol\": \"aws_sigv4\",\n"
116117
+ " \"parameters\": {\n"
117118
+ " \"region\": \""
118119
+ GITHUB_CI_AWS_REGION
119120
+ "\",\n"
120121
+ " \"service_name\": \"bedrock\",\n"
121-
+ " \"anthropic_version\": \"bedrock-2023-05-31\",\n"
122-
+ " \"max_tokens_to_sample\": 8000,\n"
123-
+ " \"temperature\": 0.0001,\n"
124-
+ " \"response_filter\": \"$.completion\",\n"
125-
+ " \"stop_sequences\": [\"\\n\\nHuman:\",\"\\nObservation:\",\"\\n\\tObservation:\",\"\\nObservation\",\"\\n\\tObservation\",\"\\n\\nQuestion\"]\n"
122+
+ " \"model\": \""
123+
+ "anthropic.claude-3-5-sonnet-20240620-v1:0"
124+
+ "\",\n"
125+
+ " \"system_prompt\": \"You are a helpful assistant.\",\n"
126+
+ "\"response_filter\": \"$.output.message.content[0].text\""
126127
+ " },\n"
127128
+ " \"credential\": {\n"
128129
+ " \"access_key\": \""
@@ -136,17 +137,22 @@ public class RestMLInferenceSearchResponseProcessorIT extends MLCommonsRestTestC
136137
+ "\"\n"
137138
+ " },\n"
138139
+ " \"actions\": [\n"
139-
+ " {\n"
140-
+ " \"action_type\": \"predict\",\n"
141-
+ " \"method\": \"POST\",\n"
142-
+ " \"url\": \"https://bedrock-runtime.${parameters.region}.amazonaws.com/model/anthropic.claude-instant-v1/invoke\",\n"
143-
+ " \"headers\": {\n"
144-
+ " \"content-type\": \"application/json\",\n"
145-
+ " \"x-amz-content-sha256\": \"required\"\n"
146-
+ " },\n"
147-
+ " \"request_body\": \"{\\\"prompt\\\":\\\"${parameters.prompt}\\\", \\\"stop_sequences\\\": ${parameters.stop_sequences}, \\\"max_tokens_to_sample\\\":${parameters.max_tokens_to_sample}, \\\"temperature\\\":${parameters.temperature}, \\\"anthropic_version\\\":\\\"${parameters.anthropic_version}\\\" }\"\n"
148-
+ " }\n"
149-
+ " ]\n"
140+
+ " {\n"
141+
+ " \"action_type\": \""
142+
+ "predict"
143+
+ "\",\n"
144+
+ " \"method\": \"POST\",\n"
145+
+ " \"headers\": {\n"
146+
+ " \"content-type\": \"application/json\"\n"
147+
+ " },\n"
148+
+ " \"url\": \"https://bedrock-runtime."
149+
+ GITHUB_CI_AWS_REGION
150+
+ ".amazonaws.com/model/"
151+
+ "anthropic.claude-3-5-sonnet-20240620-v1:0"
152+
+ "/converse\",\n"
153+
+ " \"request_body\": \"{ \\\"system\\\": [{\\\"text\\\": \\\"you are a helpful assistant.\\\"}], \\\"messages\\\":[{\\\"role\\\": \\\"user\\\", \\\"content\\\":[ {\\\"type\\\": \\\"text\\\", \\\"text\\\":\\\"${parameters.prompt}\\\"}]}] , \\\"inferenceConfig\\\": {\\\"temperature\\\": 0.0, \\\"topP\\\": 0.9, \\\"maxTokens\\\": 1000} }\"\n"
154+
+ " }\n"
155+
+ " ]\n"
150156
+ "}";
151157

152158
private final String bedrockMultiModalEmbeddingModelConnectorEntity = "{\n"

plugin/src/test/java/org/opensearch/ml/rest/RestMLRAGSearchProcessorIT.java

Lines changed: 92 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,52 @@ public class RestMLRAGSearchProcessorIT extends MLCommonsRestTestCase {
111111
private static final String BEDROCK_ANTHROPIC_CLAUDE_3_5_SONNET = "anthropic.claude-3-5-sonnet-20240620-v1:0";
112112
private static final String BEDROCK_ANTHROPIC_CLAUDE_3_SONNET = "anthropic.claude-3-sonnet-20240229-v1:0";
113113

114+
private static final String BEDROCK_CONNECTOR_BLUEPRINT_INVOKE = "{\n"
115+
+ " \"name\": \"Bedrock Connector: claude 3.5\",\n"
116+
+ " \"description\": \"The connector to bedrock claude 3.5 model\",\n"
117+
+ " \"version\": 1,\n"
118+
+ " \"protocol\": \"aws_sigv4\",\n"
119+
+ " \"parameters\": {\n"
120+
+ " \"region\": \""
121+
+ GITHUB_CI_AWS_REGION
122+
+ "\",\n"
123+
+ " \"service_name\": \"bedrock\",\n"
124+
+ " \"model\": \""
125+
+ "anthropic.claude-3-5-sonnet-20240620-v1:0"
126+
+ "\",\n"
127+
+ " \"system_prompt\": \"You are a helpful assistant.\",\n"
128+
+ "\"response_filter\": \"$.content[0].text\""
129+
+ " },\n"
130+
+ " \"credential\": {\n"
131+
+ " \"access_key\": \""
132+
+ AWS_ACCESS_KEY_ID
133+
+ "\",\n"
134+
+ " \"secret_key\": \""
135+
+ AWS_SECRET_ACCESS_KEY
136+
+ "\",\n"
137+
+ " \"session_token\": \""
138+
+ AWS_SESSION_TOKEN
139+
+ "\"\n"
140+
+ " },\n"
141+
+ " \"actions\": [\n"
142+
+ " {\n"
143+
+ " \"action_type\": \""
144+
+ "predict"
145+
+ "\",\n"
146+
+ " \"method\": \"POST\",\n"
147+
+ " \"headers\": {\n"
148+
+ " \"content-type\": \"application/json\"\n"
149+
+ " },\n"
150+
+ " \"url\": \"https://bedrock-runtime."
151+
+ GITHUB_CI_AWS_REGION
152+
+ ".amazonaws.com/model/"
153+
+ "anthropic.claude-3-5-sonnet-20240620-v1:0"
154+
+ "/invoke\",\n"
155+
+ " \"request_body\": \"{\\\"messages\\\":[{\\\"role\\\": \\\"user\\\", \\\"content\\\":[ {\\\"type\\\": \\\"text\\\", \\\"text\\\":\\\"${parameters.inputs}\\\"}]}], \\\"max_tokens\\\":300, \\\"temperature\\\":0.5, \\\"anthropic_version\\\":\\\"bedrock-2023-05-31\\\" }\"\n"
156+
+ " }\n"
157+
+ " ]\n"
158+
+ "}";
159+
114160
private static final String BEDROCK_CONNECTOR_BLUEPRINT1 = "{\n"
115161
+ " \"name\": \"Bedrock Connector: claude2\",\n"
116162
+ " \"description\": \"The connector to bedrock claude2 model\",\n"
@@ -181,7 +227,7 @@ public class RestMLRAGSearchProcessorIT extends MLCommonsRestTestCase {
181227
+ " ]\n"
182228
+ "}";
183229

184-
private static final String BEDROCK_CONVERSE_CONNECTOR_BLUEPRINT2 = "{\n"
230+
static final String BEDROCK_CONVERSE_CONNECTOR_BLUEPRINT2 = "{\n"
185231
+ " \"name\": \"Bedrock Connector: claude 3.5\",\n"
186232
+ " \"description\": \"The connector to bedrock claude 3.5 model\",\n"
187233
+ " \"version\": 1,\n"
@@ -268,8 +314,8 @@ public class RestMLRAGSearchProcessorIT extends MLCommonsRestTestCase {
268314
+ "}";
269315

270316
private static final String BEDROCK_CONNECTOR_BLUEPRINT = AWS_SESSION_TOKEN == null
271-
? BEDROCK_CONNECTOR_BLUEPRINT2
272-
: BEDROCK_CONNECTOR_BLUEPRINT1;
317+
? BEDROCK_CONNECTOR_BLUEPRINT_INVOKE
318+
: BEDROCK_CONNECTOR_BLUEPRINT_INVOKE;
273319

274320
private static final String BEDROCK_CONVERSE_CONNECTOR_BLUEPRINT = AWS_SESSION_TOKEN == null
275321
? BEDROCK_CONVERSE_CONNECTOR_BLUEPRINT2
@@ -425,6 +471,26 @@ public class RestMLRAGSearchProcessorIT extends MLCommonsRestTestCase {
425471
+ " }\n"
426472
+ "}";
427473

474+
private static final String BM25_SEARCH_REQUEST_WITH_CONVO_WITH_LLM_RESPONSE_TEMPLATE = "{\n"
475+
+ " \"_source\": [\"%s\"],\n"
476+
+ " \"query\" : {\n"
477+
+ " \"match\": {\"%s\": \"%s\"}\n"
478+
+ " },\n"
479+
+ " \"ext\": {\n"
480+
+ " \"generative_qa_parameters\": {\n"
481+
+ " \"llm_model\": \"%s\",\n"
482+
+ " \"llm_question\": \"%s\",\n"
483+
+ " \"memory_id\": \"%s\",\n"
484+
+ " \"system_prompt\": \"%s\",\n"
485+
+ " \"user_instructions\": \"%s\",\n"
486+
+ " \"context_size\": %d,\n"
487+
+ " \"message_size\": %d,\n"
488+
+ " \"timeout\": %d,\n"
489+
+ " \"llm_response_field\": \"%s\"\n"
490+
+ " }\n"
491+
+ " }\n"
492+
+ "}";
493+
428494
private static final String BM25_SEARCH_REQUEST_WITH_CONVO_AND_IMAGE_TEMPLATE = "{\n"
429495
+ " \"_source\": [\"%s\"],\n"
430496
+ " \"query\" : {\n"
@@ -705,6 +771,7 @@ public void testBM25WithBedrock() throws Exception {
705771
requestParameters.contextSize = 5;
706772
requestParameters.interactionSize = 5;
707773
requestParameters.timeout = 60;
774+
requestParameters.llmResponseField = "response";
708775
Response response2 = performSearch(INDEX_NAME, "pipeline_test", 5, requestParameters);
709776
assertEquals(200, response2.getStatusLine().getStatusCode());
710777

@@ -1068,6 +1135,7 @@ public void testBM25WithBedrockWithConversation() throws Exception {
10681135
requestParameters.interactionSize = 5;
10691136
requestParameters.timeout = 60;
10701137
requestParameters.conversationId = conversationId;
1138+
requestParameters.llmResponseField = "response";
10711139
Response response2 = performSearch(INDEX_NAME, "pipeline_test", 5, requestParameters);
10721140
assertEquals(200, response2.getStatusLine().getStatusCode());
10731141

@@ -1240,7 +1308,7 @@ private Response performSearch(String indexName, String pipeline, int size, Sear
12401308
throws Exception {
12411309

12421310
// TODO build these templates dynamically
1243-
String httpEntity = requestParameters.llmResponseField != null
1311+
String httpEntity = requestParameters.llmResponseField != null && requestParameters.conversationId == null
12441312
? String
12451313
.format(
12461314
Locale.ROOT,
@@ -1351,10 +1419,27 @@ private Response performSearch(String indexName, String pipeline, int size, Sear
13511419
requestParameters.interactionSize,
13521420
requestParameters.timeout
13531421
)
1422+
: (requestParameters.llmResponseField == null)
1423+
? String
1424+
.format(
1425+
Locale.ROOT,
1426+
BM25_SEARCH_REQUEST_WITH_CONVO_TEMPLATE,
1427+
requestParameters.source,
1428+
requestParameters.source,
1429+
requestParameters.match,
1430+
requestParameters.llmModel,
1431+
requestParameters.llmQuestion,
1432+
requestParameters.conversationId,
1433+
requestParameters.systemPrompt,
1434+
requestParameters.userInstructions,
1435+
requestParameters.contextSize,
1436+
requestParameters.interactionSize,
1437+
requestParameters.timeout
1438+
)
13541439
: String
13551440
.format(
13561441
Locale.ROOT,
1357-
BM25_SEARCH_REQUEST_WITH_CONVO_TEMPLATE,
1442+
BM25_SEARCH_REQUEST_WITH_CONVO_WITH_LLM_RESPONSE_TEMPLATE,
13581443
requestParameters.source,
13591444
requestParameters.source,
13601445
requestParameters.match,
@@ -1365,7 +1450,8 @@ private Response performSearch(String indexName, String pipeline, int size, Sear
13651450
requestParameters.userInstructions,
13661451
requestParameters.contextSize,
13671452
requestParameters.interactionSize,
1368-
requestParameters.timeout
1453+
requestParameters.timeout,
1454+
requestParameters.llmResponseField
13691455
);
13701456
return makeRequest(
13711457
client(),

0 commit comments

Comments
 (0)