Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions common/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,10 @@ dependencies {
compileOnly group: 'com.networknt' , name: 'json-schema-validator', version: '1.4.0'
// Multi-tenant SDK Client
compileOnly "org.opensearch:opensearch-remote-metadata-sdk:${opensearch_build}"
compileOnly (group: 'software.amazon.awssdk', name: 'netty-nio-client', version: "${versions.aws}") {
exclude(group: 'org.reactivestreams', module: 'reactive-streams')
exclude(group: 'org.slf4j', module: 'slf4j-api')
}
}

lombok {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,18 @@
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.engine.httpclient;
package org.opensearch.ml.common.httpclient;

import java.net.Inet4Address;
import java.net.InetAddress;
import java.net.UnknownHostException;
import java.security.AccessController;
import java.security.PrivilegedActionException;
import java.security.PrivilegedExceptionAction;
import java.time.Duration;
import java.util.Arrays;
import java.util.Locale;
import java.util.concurrent.atomic.AtomicBoolean;

import org.opensearch.common.util.concurrent.ThreadContextAccess;

import lombok.extern.log4j.Log4j2;
import software.amazon.awssdk.http.async.SdkAsyncHttpClient;
import software.amazon.awssdk.http.nio.netty.NettyNioAsyncHttpClient;
Expand All @@ -24,19 +23,15 @@
public class MLHttpClientFactory {

public static SdkAsyncHttpClient getAsyncHttpClient(Duration connectionTimeout, Duration readTimeout, int maxConnections) {
try {
return AccessController
.doPrivileged(
(PrivilegedExceptionAction<SdkAsyncHttpClient>) () -> NettyNioAsyncHttpClient
.builder()
.connectionTimeout(connectionTimeout)
.readTimeout(readTimeout)
.maxConcurrency(maxConnections)
.build()
);
} catch (PrivilegedActionException e) {
return null;
}
return ThreadContextAccess
.doPrivileged(
() -> NettyNioAsyncHttpClient
.builder()
.connectionTimeout(connectionTimeout)
.readTimeout(readTimeout)
.maxConcurrency(maxConnections)
.build()
);
}

/**
Expand All @@ -50,7 +45,7 @@ public static SdkAsyncHttpClient getAsyncHttpClient(Duration connectionTimeout,
public static void validate(String protocol, String host, int port, AtomicBoolean connectorPrivateIpEnabled)
throws UnknownHostException {
if (protocol != null && !"http".equalsIgnoreCase(protocol) && !"https".equalsIgnoreCase(protocol)) {
log.error("Remote inference protocol is not http or https: " + protocol);
log.error("Remote inference protocol is not http or https: {}", protocol);
throw new IllegalArgumentException("Protocol is not http or https: " + protocol);
}
// When port is not specified, the default port is -1, and we need to set it to 80 or 443 based on protocol.
Expand All @@ -62,7 +57,7 @@ public static void validate(String protocol, String host, int port, AtomicBoolea
}
}
if (port < 0 || port > 65536) {
log.error("Remote inference port out of range: " + port);
log.error("Remote inference port out of range: {}", port);
throw new IllegalArgumentException("Port out of range: " + port);
}
validateIp(host, connectorPrivateIpEnabled);
Expand All @@ -71,7 +66,7 @@ public static void validate(String protocol, String host, int port, AtomicBoolea
private static void validateIp(String hostName, AtomicBoolean connectorPrivateIpEnabled) throws UnknownHostException {
InetAddress[] addresses = InetAddress.getAllByName(hostName);
if ((connectorPrivateIpEnabled == null || !connectorPrivateIpEnabled.get()) && hasPrivateIpAddress(addresses)) {
log.error("Remote inference host name has private ip address: " + hostName);
log.error("Remote inference host name has private ip address: {}", hostName);
throw new IllegalArgumentException("Remote inference host name has private ip address: " + hostName);
}
}
Expand All @@ -83,35 +78,23 @@ private static boolean hasPrivateIpAddress(InetAddress[] ipAddress) {
if (bytes.length != 4) {
return true;
} else {
int firstOctets = bytes[0] & 0xff;
int firstInOctal = parseWithOctal(String.valueOf(firstOctets));
int firstInHex = Integer.parseInt(String.valueOf(firstOctets), 16);
if (firstInOctal == 127 || firstInHex == 127) {
return bytes[1] == 0 && bytes[2] == 0 && bytes[3] == 1;
} else if (firstInOctal == 10 || firstInHex == 10) {
if (isPrivateIPv4(bytes)) {
return true;
} else if (firstInOctal == 172 || firstInHex == 172) {
int secondOctets = bytes[1] & 0xff;
int secondInOctal = parseWithOctal(String.valueOf(secondOctets));
int secondInHex = Integer.parseInt(String.valueOf(secondOctets), 16);
return (secondInOctal >= 16 && secondInOctal <= 32) || (secondInHex >= 16 && secondInHex <= 32);
} else if (firstInOctal == 192 || firstInHex == 192) {
int secondOctets = bytes[1] & 0xff;
int secondInOctal = parseWithOctal(String.valueOf(secondOctets));
int secondInHex = Integer.parseInt(String.valueOf(secondOctets), 16);
return secondInOctal == 168 || secondInHex == 168;
}
}
}
}
return Arrays.stream(ipAddress).anyMatch(x -> x.isSiteLocalAddress() || x.isLoopbackAddress() || x.isAnyLocalAddress());
}

private static int parseWithOctal(String input) {
try {
return Integer.parseInt(input, 8);
} catch (NumberFormatException e) {
return Integer.parseInt(input);
}
private static boolean isPrivateIPv4(byte[] bytes) {
int first = bytes[0] & 0xff;
int second = bytes[1] & 0xff;

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Performance concern: This method attempts octal and hex parsing for every byte comparison, which will frequently throw NumberFormatException for common IP octets (192, 168, 127, etc.) since they contain invalid octal digits. Consider using simple decimal comparison since IP addresses are standardized as decimal.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Tested, these parsing are indeed not necessary as when byte[] bytes = ip.getAddress(); runs, it already parsed the strange ip address to correct numbers, so those parsing code are not get executed.

// 127.0.0.1, 10.x.x.x, 172.16-31.x.x, 192.168.x.x, 169.254.x.x
return (first == 10)
|| (first == 172 && second >= 16 && second <= 31)
|| (first == 192 && second == 168)
|| (first == 169 && second == 254);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,11 @@
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.engine.httpclient;
package org.opensearch.ml.common.httpclient;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertThrows;

import java.time.Duration;
import java.util.concurrent.atomic.AtomicBoolean;
Expand Down Expand Up @@ -33,6 +35,39 @@ public void test_getSdkAsyncHttpClient_success() {
assertNotNull(client);
}

@Test
public void test_invalidIP_localHost_privateIPDisabled() {
IllegalArgumentException e1 = assertThrows(
IllegalArgumentException.class,
() -> MLHttpClientFactory.validate(HTTP, "127.0.0.1", 80, PRIVATE_IP_DISABLED)
);
assertEquals("Remote inference host name has private ip address: 127.0.0.1", e1.getMessage());

IllegalArgumentException e2 = assertThrows(
IllegalArgumentException.class,
() -> MLHttpClientFactory.validate(HTTP, "192.168.0.1", 80, PRIVATE_IP_DISABLED)
);
assertEquals("Remote inference host name has private ip address: 192.168.0.1", e2.getMessage());

IllegalArgumentException e3 = assertThrows(
IllegalArgumentException.class,
() -> MLHttpClientFactory.validate(HTTP, "169.254.0.1", 80, PRIVATE_IP_DISABLED)
);
assertEquals("Remote inference host name has private ip address: 169.254.0.1", e3.getMessage());

IllegalArgumentException e4 = assertThrows(
IllegalArgumentException.class,
() -> MLHttpClientFactory.validate(HTTP, "172.16.0.1", 80, PRIVATE_IP_DISABLED)
);
assertEquals("Remote inference host name has private ip address: 172.16.0.1", e4.getMessage());

IllegalArgumentException e5 = assertThrows(
IllegalArgumentException.class,
() -> MLHttpClientFactory.validate(HTTP, "172.31.0.1", 80, PRIVATE_IP_DISABLED)
);
assertEquals("Remote inference host name has private ip address: 172.31.0.1", e5.getMessage());
}

@Test
public void test_validateIp_validIp_noException() throws Exception {
MLHttpClientFactory.validate(HTTP, TEST_HOST, 80, PRIVATE_IP_DISABLED);
Expand All @@ -43,6 +78,8 @@ public void test_validateIp_validIp_noException() throws Exception {
MLHttpClientFactory.validate(HTTP, "177.0.1.1", 80, PRIVATE_IP_DISABLED);
MLHttpClientFactory.validate(HTTP, "177.0.0.2", 80, PRIVATE_IP_DISABLED);
MLHttpClientFactory.validate(HTTP, "::ffff", 80, PRIVATE_IP_DISABLED);
MLHttpClientFactory.validate(HTTP, "172.32.0.1", 80, PRIVATE_IP_ENABLED);
MLHttpClientFactory.validate(HTTP, "172.2097152", 80, PRIVATE_IP_ENABLED);
Copy link
Contributor

@brianf-aws brianf-aws Sep 26, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this 172.2097152 a spelling mistake if so how come it didn't throw an exception? We may need to tune the logic to catch these edge cases

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the 172.32.0.0 representation, which is not a private ip, so it doesn't throw exception.

}

@Test
Expand Down
18 changes: 9 additions & 9 deletions ml-algorithms/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -68,27 +68,27 @@ dependencies {
}
}

implementation platform('software.amazon.awssdk:bom:2.30.18')
api 'software.amazon.awssdk:auth:2.30.18'
implementation 'software.amazon.awssdk:apache-client'
implementation ('software.amazon.awssdk:bedrockruntime') {
implementation platform(group: 'software.amazon.awssdk', name: 'bom', version:"${versions.aws}")
api 'software.amazon.awssdk:auth:${versions.aws}'
implementation group: 'software.amazon.awssdk', name:'apache-client', version: "${versions.aws}"
implementation (group: 'software.amazon.awssdk', name: 'bedrockruntime', version: "${versions.aws}") {
exclude group: 'io.netty'
}
implementation ('com.amazonaws:aws-encryption-sdk-java:2.4.1') {
exclude group: 'org.bouncycastle', module: 'bcprov-ext-jdk18on'
}
// needed by aws-encryption-sdk-java
implementation "org.bouncycastle:bc-fips:${versions.bouncycastle_jce}"
compileOnly group: 'software.amazon.awssdk', name: 'aws-core', version: "2.30.18"
compileOnly group: 'software.amazon.awssdk', name: 's3', version: "2.30.18"
compileOnly group: 'software.amazon.awssdk', name: 'regions', version: "2.30.18"
compileOnly group: 'software.amazon.awssdk', name: 'aws-core', version: "${versions.aws}"
compileOnly group: 'software.amazon.awssdk', name: 's3', version: "${versions.aws}"
compileOnly group: 'software.amazon.awssdk', name: 'regions', version: "${versions.aws}"

implementation ('com.jayway.jsonpath:json-path:2.9.0') {
exclude group: 'net.minidev', module: 'json-smart'
}
implementation('net.minidev:json-smart:2.5.2')
implementation group: 'org.json', name: 'json', version: '20231013'
implementation group: 'software.amazon.awssdk', name: 'netty-nio-client', version: "2.30.18"
implementation group: 'software.amazon.awssdk', name: 'netty-nio-client', version: "${versions.aws}"
api('io.modelcontextprotocol.sdk:mcp:0.12.1')
testImplementation("com.fasterxml.jackson.core:jackson-annotations:${versions.jackson}")
testImplementation("com.fasterxml.jackson.core:jackson-databind:${versions.jackson_databind}")
Expand All @@ -104,7 +104,7 @@ lombok {
configurations.all {
resolutionStrategy.force 'com.google.protobuf:protobuf-java:3.25.5'
resolutionStrategy.force 'org.apache.commons:commons-compress:1.26.0'
resolutionStrategy.force 'software.amazon.awssdk:bom:2.30.18'
resolutionStrategy.force group: 'software.amazon.awssdk', name:'bom', version:"${versions.aws}"
}

jacocoTestReport {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,12 @@
import org.opensearch.ml.common.connector.AwsConnector;
import org.opensearch.ml.common.connector.Connector;
import org.opensearch.ml.common.exception.MLException;
import org.opensearch.ml.common.httpclient.MLHttpClientFactory;
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.model.MLGuard;
import org.opensearch.ml.common.output.model.ModelTensors;
import org.opensearch.ml.common.transport.MLTaskResponse;
import org.opensearch.ml.engine.annotation.ConnectorExecutor;
import org.opensearch.ml.engine.httpclient.MLHttpClientFactory;
import org.opensearch.script.ScriptService;
import org.opensearch.transport.StreamTransportService;
import org.opensearch.transport.client.Client;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,13 @@
import org.opensearch.ml.common.connector.Connector;
import org.opensearch.ml.common.connector.HttpConnector;
import org.opensearch.ml.common.exception.MLException;
import org.opensearch.ml.common.httpclient.MLHttpClientFactory;
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.model.MLGuard;
import org.opensearch.ml.common.output.model.ModelTensors;
import org.opensearch.ml.common.transport.MLTaskResponse;
import org.opensearch.ml.common.utils.StringUtils;
import org.opensearch.ml.engine.annotation.ConnectorExecutor;
import org.opensearch.ml.engine.httpclient.MLHttpClientFactory;
import org.opensearch.script.ScriptService;
import org.opensearch.transport.StreamTransportService;
import org.opensearch.transport.client.Client;
Expand Down
12 changes: 6 additions & 6 deletions plugin/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -58,15 +58,15 @@ dependencies {
implementation project(':opensearch-ml-search-processors')
implementation project(':opensearch-ml-memory')

implementation group: 'software.amazon.awssdk', name: 'aws-core', version: "2.30.18"
implementation group: 'software.amazon.awssdk', name: 's3', version: "2.30.18"
implementation group: 'software.amazon.awssdk', name: 'regions', version: "2.30.18"
implementation group: 'software.amazon.awssdk', name: 'aws-core', version: "${versions.aws}"
implementation group: 'software.amazon.awssdk', name: 's3', version: "${versions.aws}"
implementation group: 'software.amazon.awssdk', name: 'regions', version: "${versions.aws}"

implementation group: 'software.amazon.awssdk', name: 'aws-xml-protocol', version: "2.30.18"
implementation group: 'software.amazon.awssdk', name: 'aws-xml-protocol', version: "${versions.aws}"

implementation group: 'software.amazon.awssdk', name: 'aws-query-protocol', version: "2.30.18"
implementation group: 'software.amazon.awssdk', name: 'aws-query-protocol', version: "${versions.aws}"

implementation group: 'software.amazon.awssdk', name: 'protocol-core', version: "2.30.18"
implementation group: 'software.amazon.awssdk', name: 'protocol-core', version: "${versions.aws}"

zipArchive group: 'org.opensearch.plugin', name:'opensearch-job-scheduler', version: "${opensearch_build}"
compileOnly "org.opensearch:opensearch-job-scheduler-spi:${opensearch_build}"
Expand Down
Loading