diff --git a/common/build.gradle b/common/build.gradle index 7b142e7a70..39b29fcafe 100644 --- a/common/build.gradle +++ b/common/build.gradle @@ -43,6 +43,10 @@ dependencies { compileOnly group: 'com.networknt' , name: 'json-schema-validator', version: '1.4.0' // Multi-tenant SDK Client compileOnly "org.opensearch:opensearch-remote-metadata-sdk:${opensearch_build}" + compileOnly (group: 'software.amazon.awssdk', name: 'netty-nio-client', version: "${versions.aws}") { + exclude(group: 'org.reactivestreams', module: 'reactive-streams') + exclude(group: 'org.slf4j', module: 'slf4j-api') + } } lombok { diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/httpclient/MLHttpClientFactory.java b/common/src/main/java/org/opensearch/ml/common/httpclient/MLHttpClientFactory.java similarity index 58% rename from ml-algorithms/src/main/java/org/opensearch/ml/engine/httpclient/MLHttpClientFactory.java rename to common/src/main/java/org/opensearch/ml/common/httpclient/MLHttpClientFactory.java index ffc95c30de..109a5de5f8 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/httpclient/MLHttpClientFactory.java +++ b/common/src/main/java/org/opensearch/ml/common/httpclient/MLHttpClientFactory.java @@ -3,19 +3,18 @@ * SPDX-License-Identifier: Apache-2.0 */ -package org.opensearch.ml.engine.httpclient; +package org.opensearch.ml.common.httpclient; import java.net.Inet4Address; import java.net.InetAddress; import java.net.UnknownHostException; -import java.security.AccessController; -import java.security.PrivilegedActionException; -import java.security.PrivilegedExceptionAction; import java.time.Duration; import java.util.Arrays; import java.util.Locale; import java.util.concurrent.atomic.AtomicBoolean; +import org.opensearch.common.util.concurrent.ThreadContextAccess; + import lombok.extern.log4j.Log4j2; import software.amazon.awssdk.http.async.SdkAsyncHttpClient; import software.amazon.awssdk.http.nio.netty.NettyNioAsyncHttpClient; @@ -24,19 +23,15 @@ public class MLHttpClientFactory { public static SdkAsyncHttpClient getAsyncHttpClient(Duration connectionTimeout, Duration readTimeout, int maxConnections) { - try { - return AccessController - .doPrivileged( - (PrivilegedExceptionAction) () -> NettyNioAsyncHttpClient - .builder() - .connectionTimeout(connectionTimeout) - .readTimeout(readTimeout) - .maxConcurrency(maxConnections) - .build() - ); - } catch (PrivilegedActionException e) { - return null; - } + return ThreadContextAccess + .doPrivileged( + () -> NettyNioAsyncHttpClient + .builder() + .connectionTimeout(connectionTimeout) + .readTimeout(readTimeout) + .maxConcurrency(maxConnections) + .build() + ); } /** @@ -50,7 +45,7 @@ public static SdkAsyncHttpClient getAsyncHttpClient(Duration connectionTimeout, public static void validate(String protocol, String host, int port, AtomicBoolean connectorPrivateIpEnabled) throws UnknownHostException { if (protocol != null && !"http".equalsIgnoreCase(protocol) && !"https".equalsIgnoreCase(protocol)) { - log.error("Remote inference protocol is not http or https: " + protocol); + log.error("Remote inference protocol is not http or https: {}", protocol); throw new IllegalArgumentException("Protocol is not http or https: " + protocol); } // When port is not specified, the default port is -1, and we need to set it to 80 or 443 based on protocol. @@ -62,7 +57,7 @@ public static void validate(String protocol, String host, int port, AtomicBoolea } } if (port < 0 || port > 65536) { - log.error("Remote inference port out of range: " + port); + log.error("Remote inference port out of range: {}", port); throw new IllegalArgumentException("Port out of range: " + port); } validateIp(host, connectorPrivateIpEnabled); @@ -71,7 +66,7 @@ public static void validate(String protocol, String host, int port, AtomicBoolea private static void validateIp(String hostName, AtomicBoolean connectorPrivateIpEnabled) throws UnknownHostException { InetAddress[] addresses = InetAddress.getAllByName(hostName); if ((connectorPrivateIpEnabled == null || !connectorPrivateIpEnabled.get()) && hasPrivateIpAddress(addresses)) { - log.error("Remote inference host name has private ip address: " + hostName); + log.error("Remote inference host name has private ip address: {}", hostName); throw new IllegalArgumentException("Remote inference host name has private ip address: " + hostName); } } @@ -83,23 +78,8 @@ private static boolean hasPrivateIpAddress(InetAddress[] ipAddress) { if (bytes.length != 4) { return true; } else { - int firstOctets = bytes[0] & 0xff; - int firstInOctal = parseWithOctal(String.valueOf(firstOctets)); - int firstInHex = Integer.parseInt(String.valueOf(firstOctets), 16); - if (firstInOctal == 127 || firstInHex == 127) { - return bytes[1] == 0 && bytes[2] == 0 && bytes[3] == 1; - } else if (firstInOctal == 10 || firstInHex == 10) { + if (isPrivateIPv4(bytes)) { return true; - } else if (firstInOctal == 172 || firstInHex == 172) { - int secondOctets = bytes[1] & 0xff; - int secondInOctal = parseWithOctal(String.valueOf(secondOctets)); - int secondInHex = Integer.parseInt(String.valueOf(secondOctets), 16); - return (secondInOctal >= 16 && secondInOctal <= 32) || (secondInHex >= 16 && secondInHex <= 32); - } else if (firstInOctal == 192 || firstInHex == 192) { - int secondOctets = bytes[1] & 0xff; - int secondInOctal = parseWithOctal(String.valueOf(secondOctets)); - int secondInHex = Integer.parseInt(String.valueOf(secondOctets), 16); - return secondInOctal == 168 || secondInHex == 168; } } } @@ -107,11 +87,14 @@ private static boolean hasPrivateIpAddress(InetAddress[] ipAddress) { return Arrays.stream(ipAddress).anyMatch(x -> x.isSiteLocalAddress() || x.isLoopbackAddress() || x.isAnyLocalAddress()); } - private static int parseWithOctal(String input) { - try { - return Integer.parseInt(input, 8); - } catch (NumberFormatException e) { - return Integer.parseInt(input); - } + private static boolean isPrivateIPv4(byte[] bytes) { + int first = bytes[0] & 0xff; + int second = bytes[1] & 0xff; + + // 127.0.0.1, 10.x.x.x, 172.16-31.x.x, 192.168.x.x, 169.254.x.x + return (first == 10) + || (first == 172 && second >= 16 && second <= 31) + || (first == 192 && second == 168) + || (first == 169 && second == 254); } } diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/httpclient/MLHttpClientFactoryTests.java b/common/src/test/java/org/opensearch/ml/common/httpclient/MLHttpClientFactoryTests.java similarity index 76% rename from ml-algorithms/src/test/java/org/opensearch/ml/engine/httpclient/MLHttpClientFactoryTests.java rename to common/src/test/java/org/opensearch/ml/common/httpclient/MLHttpClientFactoryTests.java index ca626158b2..1c01172344 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/httpclient/MLHttpClientFactoryTests.java +++ b/common/src/test/java/org/opensearch/ml/common/httpclient/MLHttpClientFactoryTests.java @@ -3,9 +3,11 @@ * SPDX-License-Identifier: Apache-2.0 */ -package org.opensearch.ml.engine.httpclient; +package org.opensearch.ml.common.httpclient; +import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertThrows; import java.time.Duration; import java.util.concurrent.atomic.AtomicBoolean; @@ -33,6 +35,39 @@ public void test_getSdkAsyncHttpClient_success() { assertNotNull(client); } + @Test + public void test_invalidIP_localHost_privateIPDisabled() { + IllegalArgumentException e1 = assertThrows( + IllegalArgumentException.class, + () -> MLHttpClientFactory.validate(HTTP, "127.0.0.1", 80, PRIVATE_IP_DISABLED) + ); + assertEquals("Remote inference host name has private ip address: 127.0.0.1", e1.getMessage()); + + IllegalArgumentException e2 = assertThrows( + IllegalArgumentException.class, + () -> MLHttpClientFactory.validate(HTTP, "192.168.0.1", 80, PRIVATE_IP_DISABLED) + ); + assertEquals("Remote inference host name has private ip address: 192.168.0.1", e2.getMessage()); + + IllegalArgumentException e3 = assertThrows( + IllegalArgumentException.class, + () -> MLHttpClientFactory.validate(HTTP, "169.254.0.1", 80, PRIVATE_IP_DISABLED) + ); + assertEquals("Remote inference host name has private ip address: 169.254.0.1", e3.getMessage()); + + IllegalArgumentException e4 = assertThrows( + IllegalArgumentException.class, + () -> MLHttpClientFactory.validate(HTTP, "172.16.0.1", 80, PRIVATE_IP_DISABLED) + ); + assertEquals("Remote inference host name has private ip address: 172.16.0.1", e4.getMessage()); + + IllegalArgumentException e5 = assertThrows( + IllegalArgumentException.class, + () -> MLHttpClientFactory.validate(HTTP, "172.31.0.1", 80, PRIVATE_IP_DISABLED) + ); + assertEquals("Remote inference host name has private ip address: 172.31.0.1", e5.getMessage()); + } + @Test public void test_validateIp_validIp_noException() throws Exception { MLHttpClientFactory.validate(HTTP, TEST_HOST, 80, PRIVATE_IP_DISABLED); @@ -43,6 +78,8 @@ public void test_validateIp_validIp_noException() throws Exception { MLHttpClientFactory.validate(HTTP, "177.0.1.1", 80, PRIVATE_IP_DISABLED); MLHttpClientFactory.validate(HTTP, "177.0.0.2", 80, PRIVATE_IP_DISABLED); MLHttpClientFactory.validate(HTTP, "::ffff", 80, PRIVATE_IP_DISABLED); + MLHttpClientFactory.validate(HTTP, "172.32.0.1", 80, PRIVATE_IP_ENABLED); + MLHttpClientFactory.validate(HTTP, "172.2097152", 80, PRIVATE_IP_ENABLED); } @Test diff --git a/ml-algorithms/build.gradle b/ml-algorithms/build.gradle index 461dba94c7..d0748e4175 100644 --- a/ml-algorithms/build.gradle +++ b/ml-algorithms/build.gradle @@ -68,10 +68,10 @@ dependencies { } } - implementation platform('software.amazon.awssdk:bom:2.30.18') - api 'software.amazon.awssdk:auth:2.30.18' - implementation 'software.amazon.awssdk:apache-client' - implementation ('software.amazon.awssdk:bedrockruntime') { + implementation platform(group: 'software.amazon.awssdk', name: 'bom', version:"${versions.aws}") + api 'software.amazon.awssdk:auth:${versions.aws}' + implementation group: 'software.amazon.awssdk', name:'apache-client', version: "${versions.aws}" + implementation (group: 'software.amazon.awssdk', name: 'bedrockruntime', version: "${versions.aws}") { exclude group: 'io.netty' } implementation ('com.amazonaws:aws-encryption-sdk-java:2.4.1') { @@ -79,16 +79,16 @@ dependencies { } // needed by aws-encryption-sdk-java implementation "org.bouncycastle:bc-fips:${versions.bouncycastle_jce}" - compileOnly group: 'software.amazon.awssdk', name: 'aws-core', version: "2.30.18" - compileOnly group: 'software.amazon.awssdk', name: 's3', version: "2.30.18" - compileOnly group: 'software.amazon.awssdk', name: 'regions', version: "2.30.18" + compileOnly group: 'software.amazon.awssdk', name: 'aws-core', version: "${versions.aws}" + compileOnly group: 'software.amazon.awssdk', name: 's3', version: "${versions.aws}" + compileOnly group: 'software.amazon.awssdk', name: 'regions', version: "${versions.aws}" implementation ('com.jayway.jsonpath:json-path:2.9.0') { exclude group: 'net.minidev', module: 'json-smart' } implementation('net.minidev:json-smart:2.5.2') implementation group: 'org.json', name: 'json', version: '20231013' - implementation group: 'software.amazon.awssdk', name: 'netty-nio-client', version: "2.30.18" + implementation group: 'software.amazon.awssdk', name: 'netty-nio-client', version: "${versions.aws}" api('io.modelcontextprotocol.sdk:mcp:0.12.1') testImplementation("com.fasterxml.jackson.core:jackson-annotations:${versions.jackson}") testImplementation("com.fasterxml.jackson.core:jackson-databind:${versions.jackson_databind}") @@ -104,7 +104,7 @@ lombok { configurations.all { resolutionStrategy.force 'com.google.protobuf:protobuf-java:3.25.5' resolutionStrategy.force 'org.apache.commons:commons-compress:1.26.0' - resolutionStrategy.force 'software.amazon.awssdk:bom:2.30.18' + resolutionStrategy.force group: 'software.amazon.awssdk', name:'bom', version:"${versions.aws}" } jacocoTestReport { diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutor.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutor.java index 7012d057a5..484058d550 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutor.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/AwsConnectorExecutor.java @@ -28,12 +28,12 @@ import org.opensearch.ml.common.connector.AwsConnector; import org.opensearch.ml.common.connector.Connector; import org.opensearch.ml.common.exception.MLException; +import org.opensearch.ml.common.httpclient.MLHttpClientFactory; import org.opensearch.ml.common.input.MLInput; import org.opensearch.ml.common.model.MLGuard; import org.opensearch.ml.common.output.model.ModelTensors; import org.opensearch.ml.common.transport.MLTaskResponse; import org.opensearch.ml.engine.annotation.ConnectorExecutor; -import org.opensearch.ml.engine.httpclient.MLHttpClientFactory; import org.opensearch.script.ScriptService; import org.opensearch.transport.StreamTransportService; import org.opensearch.transport.client.Client; diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutor.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutor.java index ec5b18f28c..6a408ba61d 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutor.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/remote/HttpJsonConnectorExecutor.java @@ -32,13 +32,13 @@ import org.opensearch.ml.common.connector.Connector; import org.opensearch.ml.common.connector.HttpConnector; import org.opensearch.ml.common.exception.MLException; +import org.opensearch.ml.common.httpclient.MLHttpClientFactory; import org.opensearch.ml.common.input.MLInput; import org.opensearch.ml.common.model.MLGuard; import org.opensearch.ml.common.output.model.ModelTensors; import org.opensearch.ml.common.transport.MLTaskResponse; import org.opensearch.ml.common.utils.StringUtils; import org.opensearch.ml.engine.annotation.ConnectorExecutor; -import org.opensearch.ml.engine.httpclient.MLHttpClientFactory; import org.opensearch.script.ScriptService; import org.opensearch.transport.StreamTransportService; import org.opensearch.transport.client.Client; diff --git a/plugin/build.gradle b/plugin/build.gradle index 16487b9544..5245df707b 100644 --- a/plugin/build.gradle +++ b/plugin/build.gradle @@ -58,15 +58,15 @@ dependencies { implementation project(':opensearch-ml-search-processors') implementation project(':opensearch-ml-memory') - implementation group: 'software.amazon.awssdk', name: 'aws-core', version: "2.30.18" - implementation group: 'software.amazon.awssdk', name: 's3', version: "2.30.18" - implementation group: 'software.amazon.awssdk', name: 'regions', version: "2.30.18" + implementation group: 'software.amazon.awssdk', name: 'aws-core', version: "${versions.aws}" + implementation group: 'software.amazon.awssdk', name: 's3', version: "${versions.aws}" + implementation group: 'software.amazon.awssdk', name: 'regions', version: "${versions.aws}" - implementation group: 'software.amazon.awssdk', name: 'aws-xml-protocol', version: "2.30.18" + implementation group: 'software.amazon.awssdk', name: 'aws-xml-protocol', version: "${versions.aws}" - implementation group: 'software.amazon.awssdk', name: 'aws-query-protocol', version: "2.30.18" + implementation group: 'software.amazon.awssdk', name: 'aws-query-protocol', version: "${versions.aws}" - implementation group: 'software.amazon.awssdk', name: 'protocol-core', version: "2.30.18" + implementation group: 'software.amazon.awssdk', name: 'protocol-core', version: "${versions.aws}" zipArchive group: 'org.opensearch.plugin', name:'opensearch-job-scheduler', version: "${opensearch_build}" compileOnly "org.opensearch:opensearch-job-scheduler-spi:${opensearch_build}"