Skip to content

Commit 5883f54

Browse files
authored
Move HttpClientFactory to common to expose to other components (#4175)
* Move HttpClientFactory to common to expose to other componenets Signed-off-by: zane-neo <[email protected]> * optimize code for better maintainability Signed-off-by: zane-neo <[email protected]> * Optimize code and increase UT coverage Signed-off-by: zane-neo <[email protected]> * Address comments Signed-off-by: zane-neo <[email protected]> * Use amazon aws version from opensearch core Signed-off-by: zane-neo <[email protected]> * address comments Signed-off-by: zane-neo <[email protected]> --------- Signed-off-by: zane-neo <[email protected]>
1 parent e0e64ad commit 5883f54

File tree

7 files changed

+84
-60
lines changed

7 files changed

+84
-60
lines changed

common/build.gradle

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,10 @@ dependencies {
4343
compileOnly group: 'com.networknt' , name: 'json-schema-validator', version: '1.4.0'
4444
// Multi-tenant SDK Client
4545
compileOnly "org.opensearch:opensearch-remote-metadata-sdk:${opensearch_build}"
46+
compileOnly (group: 'software.amazon.awssdk', name: 'netty-nio-client', version: "${versions.aws}") {
47+
exclude(group: 'org.reactivestreams', module: 'reactive-streams')
48+
exclude(group: 'org.slf4j', module: 'slf4j-api')
49+
}
4650
}
4751

4852
lombok {
Lines changed: 25 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -3,19 +3,18 @@
33
* SPDX-License-Identifier: Apache-2.0
44
*/
55

6-
package org.opensearch.ml.engine.httpclient;
6+
package org.opensearch.ml.common.httpclient;
77

88
import java.net.Inet4Address;
99
import java.net.InetAddress;
1010
import java.net.UnknownHostException;
11-
import java.security.AccessController;
12-
import java.security.PrivilegedActionException;
13-
import java.security.PrivilegedExceptionAction;
1411
import java.time.Duration;
1512
import java.util.Arrays;
1613
import java.util.Locale;
1714
import java.util.concurrent.atomic.AtomicBoolean;
1815

16+
import org.opensearch.common.util.concurrent.ThreadContextAccess;
17+
1918
import lombok.extern.log4j.Log4j2;
2019
import software.amazon.awssdk.http.async.SdkAsyncHttpClient;
2120
import software.amazon.awssdk.http.nio.netty.NettyNioAsyncHttpClient;
@@ -24,19 +23,15 @@
2423
public class MLHttpClientFactory {
2524

2625
public static SdkAsyncHttpClient getAsyncHttpClient(Duration connectionTimeout, Duration readTimeout, int maxConnections) {
27-
try {
28-
return AccessController
29-
.doPrivileged(
30-
(PrivilegedExceptionAction<SdkAsyncHttpClient>) () -> NettyNioAsyncHttpClient
31-
.builder()
32-
.connectionTimeout(connectionTimeout)
33-
.readTimeout(readTimeout)
34-
.maxConcurrency(maxConnections)
35-
.build()
36-
);
37-
} catch (PrivilegedActionException e) {
38-
return null;
39-
}
26+
return ThreadContextAccess
27+
.doPrivileged(
28+
() -> NettyNioAsyncHttpClient
29+
.builder()
30+
.connectionTimeout(connectionTimeout)
31+
.readTimeout(readTimeout)
32+
.maxConcurrency(maxConnections)
33+
.build()
34+
);
4035
}
4136

4237
/**
@@ -50,7 +45,7 @@ public static SdkAsyncHttpClient getAsyncHttpClient(Duration connectionTimeout,
5045
public static void validate(String protocol, String host, int port, AtomicBoolean connectorPrivateIpEnabled)
5146
throws UnknownHostException {
5247
if (protocol != null && !"http".equalsIgnoreCase(protocol) && !"https".equalsIgnoreCase(protocol)) {
53-
log.error("Remote inference protocol is not http or https: " + protocol);
48+
log.error("Remote inference protocol is not http or https: {}", protocol);
5449
throw new IllegalArgumentException("Protocol is not http or https: " + protocol);
5550
}
5651
// When port is not specified, the default port is -1, and we need to set it to 80 or 443 based on protocol.
@@ -62,7 +57,7 @@ public static void validate(String protocol, String host, int port, AtomicBoolea
6257
}
6358
}
6459
if (port < 0 || port > 65536) {
65-
log.error("Remote inference port out of range: " + port);
60+
log.error("Remote inference port out of range: {}", port);
6661
throw new IllegalArgumentException("Port out of range: " + port);
6762
}
6863
validateIp(host, connectorPrivateIpEnabled);
@@ -71,7 +66,7 @@ public static void validate(String protocol, String host, int port, AtomicBoolea
7166
private static void validateIp(String hostName, AtomicBoolean connectorPrivateIpEnabled) throws UnknownHostException {
7267
InetAddress[] addresses = InetAddress.getAllByName(hostName);
7368
if ((connectorPrivateIpEnabled == null || !connectorPrivateIpEnabled.get()) && hasPrivateIpAddress(addresses)) {
74-
log.error("Remote inference host name has private ip address: " + hostName);
69+
log.error("Remote inference host name has private ip address: {}", hostName);
7570
throw new IllegalArgumentException("Remote inference host name has private ip address: " + hostName);
7671
}
7772
}
@@ -83,35 +78,23 @@ private static boolean hasPrivateIpAddress(InetAddress[] ipAddress) {
8378
if (bytes.length != 4) {
8479
return true;
8580
} else {
86-
int firstOctets = bytes[0] & 0xff;
87-
int firstInOctal = parseWithOctal(String.valueOf(firstOctets));
88-
int firstInHex = Integer.parseInt(String.valueOf(firstOctets), 16);
89-
if (firstInOctal == 127 || firstInHex == 127) {
90-
return bytes[1] == 0 && bytes[2] == 0 && bytes[3] == 1;
91-
} else if (firstInOctal == 10 || firstInHex == 10) {
81+
if (isPrivateIPv4(bytes)) {
9282
return true;
93-
} else if (firstInOctal == 172 || firstInHex == 172) {
94-
int secondOctets = bytes[1] & 0xff;
95-
int secondInOctal = parseWithOctal(String.valueOf(secondOctets));
96-
int secondInHex = Integer.parseInt(String.valueOf(secondOctets), 16);
97-
return (secondInOctal >= 16 && secondInOctal <= 32) || (secondInHex >= 16 && secondInHex <= 32);
98-
} else if (firstInOctal == 192 || firstInHex == 192) {
99-
int secondOctets = bytes[1] & 0xff;
100-
int secondInOctal = parseWithOctal(String.valueOf(secondOctets));
101-
int secondInHex = Integer.parseInt(String.valueOf(secondOctets), 16);
102-
return secondInOctal == 168 || secondInHex == 168;
10383
}
10484
}
10585
}
10686
}
10787
return Arrays.stream(ipAddress).anyMatch(x -> x.isSiteLocalAddress() || x.isLoopbackAddress() || x.isAnyLocalAddress());
10888
}
10989

110-
private static int parseWithOctal(String input) {
111-
try {
112-
return Integer.parseInt(input, 8);
113-
} catch (NumberFormatException e) {
114-
return Integer.parseInt(input);
115-
}
90+
private static boolean isPrivateIPv4(byte[] bytes) {
91+
int first = bytes[0] & 0xff;
92+
int second = bytes[1] & 0xff;
93+
94+
// 127.0.0.1, 10.x.x.x, 172.16-31.x.x, 192.168.x.x, 169.254.x.x
95+
return (first == 10)
96+
|| (first == 172 && second >= 16 && second <= 31)
97+
|| (first == 192 && second == 168)
98+
|| (first == 169 && second == 254);
11699
}
117100
}
Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,11 @@
33
* SPDX-License-Identifier: Apache-2.0
44
*/
55

6-
package org.opensearch.ml.engine.httpclient;
6+
package org.opensearch.ml.common.httpclient;
77

8+
import static org.junit.Assert.assertEquals;
89
import static org.junit.Assert.assertNotNull;
10+
import static org.junit.Assert.assertThrows;
911

1012
import java.time.Duration;
1113
import java.util.concurrent.atomic.AtomicBoolean;
@@ -33,6 +35,39 @@ public void test_getSdkAsyncHttpClient_success() {
3335
assertNotNull(client);
3436
}
3537

38+
@Test
39+
public void test_invalidIP_localHost_privateIPDisabled() {
40+
IllegalArgumentException e1 = assertThrows(
41+
IllegalArgumentException.class,
42+
() -> MLHttpClientFactory.validate(HTTP, "127.0.0.1", 80, PRIVATE_IP_DISABLED)
43+
);
44+
assertEquals("Remote inference host name has private ip address: 127.0.0.1", e1.getMessage());
45+
46+
IllegalArgumentException e2 = assertThrows(
47+
IllegalArgumentException.class,
48+
() -> MLHttpClientFactory.validate(HTTP, "192.168.0.1", 80, PRIVATE_IP_DISABLED)
49+
);
50+
assertEquals("Remote inference host name has private ip address: 192.168.0.1", e2.getMessage());
51+
52+
IllegalArgumentException e3 = assertThrows(
53+
IllegalArgumentException.class,
54+
() -> MLHttpClientFactory.validate(HTTP, "169.254.0.1", 80, PRIVATE_IP_DISABLED)
55+
);
56+
assertEquals("Remote inference host name has private ip address: 169.254.0.1", e3.getMessage());
57+
58+
IllegalArgumentException e4 = assertThrows(
59+
IllegalArgumentException.class,
60+
() -> MLHttpClientFactory.validate(HTTP, "172.16.0.1", 80, PRIVATE_IP_DISABLED)
61+
);
62+
assertEquals("Remote inference host name has private ip address: 172.16.0.1", e4.getMessage());
63+
64+
IllegalArgumentException e5 = assertThrows(
65+
IllegalArgumentException.class,
66+
() -> MLHttpClientFactory.validate(HTTP, "172.31.0.1", 80, PRIVATE_IP_DISABLED)
67+
);
68+
assertEquals("Remote inference host name has private ip address: 172.31.0.1", e5.getMessage());
69+
}
70+
3671
@Test
3772
public void test_validateIp_validIp_noException() throws Exception {
3873
MLHttpClientFactory.validate(HTTP, TEST_HOST, 80, PRIVATE_IP_DISABLED);
@@ -43,6 +78,8 @@ public void test_validateIp_validIp_noException() throws Exception {
4378
MLHttpClientFactory.validate(HTTP, "177.0.1.1", 80, PRIVATE_IP_DISABLED);
4479
MLHttpClientFactory.validate(HTTP, "177.0.0.2", 80, PRIVATE_IP_DISABLED);
4580
MLHttpClientFactory.validate(HTTP, "::ffff", 80, PRIVATE_IP_DISABLED);
81+
MLHttpClientFactory.validate(HTTP, "172.32.0.1", 80, PRIVATE_IP_ENABLED);
82+
MLHttpClientFactory.validate(HTTP, "172.2097152", 80, PRIVATE_IP_ENABLED);
4683
}
4784

4885
@Test

ml-algorithms/build.gradle

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -68,27 +68,27 @@ dependencies {
6868
}
6969
}
7070

71-
implementation platform('software.amazon.awssdk:bom:2.30.18')
72-
api 'software.amazon.awssdk:auth:2.30.18'
73-
implementation 'software.amazon.awssdk:apache-client'
74-
implementation ('software.amazon.awssdk:bedrockruntime') {
71+
implementation platform(group: 'software.amazon.awssdk', name: 'bom', version:"${versions.aws}")
72+
api 'software.amazon.awssdk:auth:${versions.aws}'
73+
implementation group: 'software.amazon.awssdk', name:'apache-client', version: "${versions.aws}"
74+
implementation (group: 'software.amazon.awssdk', name: 'bedrockruntime', version: "${versions.aws}") {
7575
exclude group: 'io.netty'
7676
}
7777
implementation ('com.amazonaws:aws-encryption-sdk-java:2.4.1') {
7878
exclude group: 'org.bouncycastle', module: 'bcprov-ext-jdk18on'
7979
}
8080
// needed by aws-encryption-sdk-java
8181
implementation "org.bouncycastle:bc-fips:${versions.bouncycastle_jce}"
82-
compileOnly group: 'software.amazon.awssdk', name: 'aws-core', version: "2.30.18"
83-
compileOnly group: 'software.amazon.awssdk', name: 's3', version: "2.30.18"
84-
compileOnly group: 'software.amazon.awssdk', name: 'regions', version: "2.30.18"
82+
compileOnly group: 'software.amazon.awssdk', name: 'aws-core', version: "${versions.aws}"
83+
compileOnly group: 'software.amazon.awssdk', name: 's3', version: "${versions.aws}"
84+
compileOnly group: 'software.amazon.awssdk', name: 'regions', version: "${versions.aws}"
8585

8686
implementation ('com.jayway.jsonpath:json-path:2.9.0') {
8787
exclude group: 'net.minidev', module: 'json-smart'
8888
}
8989
implementation('net.minidev:json-smart:2.5.2')
9090
implementation group: 'org.json', name: 'json', version: '20231013'
91-
implementation group: 'software.amazon.awssdk', name: 'netty-nio-client', version: "2.30.18"
91+
implementation group: 'software.amazon.awssdk', name: 'netty-nio-client', version: "${versions.aws}"
9292
api('io.modelcontextprotocol.sdk:mcp:0.12.1')
9393
testImplementation("com.fasterxml.jackson.core:jackson-annotations:${versions.jackson}")
9494
testImplementation("com.fasterxml.jackson.core:jackson-databind:${versions.jackson_databind}")
@@ -104,7 +104,7 @@ lombok {
104104
configurations.all {
105105
resolutionStrategy.force 'com.google.protobuf:protobuf-java:3.25.5'
106106
resolutionStrategy.force 'org.apache.commons:commons-compress:1.26.0'
107-
resolutionStrategy.force 'software.amazon.awssdk:bom:2.30.18'
107+
resolutionStrategy.force group: 'software.amazon.awssdk', name:'bom', version:"${versions.aws}"
108108
}
109109

110110
jacocoTestReport {

ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutor.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,12 +28,12 @@
2828
import org.opensearch.ml.common.connector.AwsConnector;
2929
import org.opensearch.ml.common.connector.Connector;
3030
import org.opensearch.ml.common.exception.MLException;
31+
import org.opensearch.ml.common.httpclient.MLHttpClientFactory;
3132
import org.opensearch.ml.common.input.MLInput;
3233
import org.opensearch.ml.common.model.MLGuard;
3334
import org.opensearch.ml.common.output.model.ModelTensors;
3435
import org.opensearch.ml.common.transport.MLTaskResponse;
3536
import org.opensearch.ml.engine.annotation.ConnectorExecutor;
36-
import org.opensearch.ml.engine.httpclient.MLHttpClientFactory;
3737
import org.opensearch.script.ScriptService;
3838
import org.opensearch.transport.StreamTransportService;
3939
import org.opensearch.transport.client.Client;

ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutor.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,13 +32,13 @@
3232
import org.opensearch.ml.common.connector.Connector;
3333
import org.opensearch.ml.common.connector.HttpConnector;
3434
import org.opensearch.ml.common.exception.MLException;
35+
import org.opensearch.ml.common.httpclient.MLHttpClientFactory;
3536
import org.opensearch.ml.common.input.MLInput;
3637
import org.opensearch.ml.common.model.MLGuard;
3738
import org.opensearch.ml.common.output.model.ModelTensors;
3839
import org.opensearch.ml.common.transport.MLTaskResponse;
3940
import org.opensearch.ml.common.utils.StringUtils;
4041
import org.opensearch.ml.engine.annotation.ConnectorExecutor;
41-
import org.opensearch.ml.engine.httpclient.MLHttpClientFactory;
4242
import org.opensearch.script.ScriptService;
4343
import org.opensearch.transport.StreamTransportService;
4444
import org.opensearch.transport.client.Client;

plugin/build.gradle

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -58,15 +58,15 @@ dependencies {
5858
implementation project(':opensearch-ml-search-processors')
5959
implementation project(':opensearch-ml-memory')
6060

61-
implementation group: 'software.amazon.awssdk', name: 'aws-core', version: "2.30.18"
62-
implementation group: 'software.amazon.awssdk', name: 's3', version: "2.30.18"
63-
implementation group: 'software.amazon.awssdk', name: 'regions', version: "2.30.18"
61+
implementation group: 'software.amazon.awssdk', name: 'aws-core', version: "${versions.aws}"
62+
implementation group: 'software.amazon.awssdk', name: 's3', version: "${versions.aws}"
63+
implementation group: 'software.amazon.awssdk', name: 'regions', version: "${versions.aws}"
6464

65-
implementation group: 'software.amazon.awssdk', name: 'aws-xml-protocol', version: "2.30.18"
65+
implementation group: 'software.amazon.awssdk', name: 'aws-xml-protocol', version: "${versions.aws}"
6666

67-
implementation group: 'software.amazon.awssdk', name: 'aws-query-protocol', version: "2.30.18"
67+
implementation group: 'software.amazon.awssdk', name: 'aws-query-protocol', version: "${versions.aws}"
6868

69-
implementation group: 'software.amazon.awssdk', name: 'protocol-core', version: "2.30.18"
69+
implementation group: 'software.amazon.awssdk', name: 'protocol-core', version: "${versions.aws}"
7070

7171
zipArchive group: 'org.opensearch.plugin', name:'opensearch-job-scheduler', version: "${opensearch_build}"
7272
compileOnly "org.opensearch:opensearch-job-scheduler-spi:${opensearch_build}"

0 commit comments

Comments
 (0)