33 * SPDX-License-Identifier: Apache-2.0
44 */
55
6- package org .opensearch .ml .engine .httpclient ;
6+ package org .opensearch .ml .common .httpclient ;
77
88import java .net .Inet4Address ;
99import java .net .InetAddress ;
1010import java .net .UnknownHostException ;
11- import java .security .AccessController ;
12- import java .security .PrivilegedActionException ;
13- import java .security .PrivilegedExceptionAction ;
1411import java .time .Duration ;
1512import java .util .Arrays ;
1613import java .util .Locale ;
1714import java .util .concurrent .atomic .AtomicBoolean ;
1815
16+ import org .opensearch .common .util .concurrent .ThreadContextAccess ;
17+
1918import lombok .extern .log4j .Log4j2 ;
2019import software .amazon .awssdk .http .async .SdkAsyncHttpClient ;
2120import software .amazon .awssdk .http .nio .netty .NettyNioAsyncHttpClient ;
2423public class MLHttpClientFactory {
2524
2625 public static SdkAsyncHttpClient getAsyncHttpClient (Duration connectionTimeout , Duration readTimeout , int maxConnections ) {
27- try {
28- return AccessController
29- .doPrivileged (
30- (PrivilegedExceptionAction <SdkAsyncHttpClient >) () -> NettyNioAsyncHttpClient
31- .builder ()
32- .connectionTimeout (connectionTimeout )
33- .readTimeout (readTimeout )
34- .maxConcurrency (maxConnections )
35- .build ()
36- );
37- } catch (PrivilegedActionException e ) {
38- return null ;
39- }
26+ return ThreadContextAccess
27+ .doPrivileged (
28+ () -> NettyNioAsyncHttpClient
29+ .builder ()
30+ .connectionTimeout (connectionTimeout )
31+ .readTimeout (readTimeout )
32+ .maxConcurrency (maxConnections )
33+ .build ()
34+ );
4035 }
4136
4237 /**
@@ -50,7 +45,7 @@ public static SdkAsyncHttpClient getAsyncHttpClient(Duration connectionTimeout,
5045 public static void validate (String protocol , String host , int port , AtomicBoolean connectorPrivateIpEnabled )
5146 throws UnknownHostException {
5247 if (protocol != null && !"http" .equalsIgnoreCase (protocol ) && !"https" .equalsIgnoreCase (protocol )) {
53- log .error ("Remote inference protocol is not http or https: " + protocol );
48+ log .error ("Remote inference protocol is not http or https: {}" , protocol );
5449 throw new IllegalArgumentException ("Protocol is not http or https: " + protocol );
5550 }
5651 // When port is not specified, the default port is -1, and we need to set it to 80 or 443 based on protocol.
@@ -62,7 +57,7 @@ public static void validate(String protocol, String host, int port, AtomicBoolea
6257 }
6358 }
6459 if (port < 0 || port > 65536 ) {
65- log .error ("Remote inference port out of range: " + port );
60+ log .error ("Remote inference port out of range: {}" , port );
6661 throw new IllegalArgumentException ("Port out of range: " + port );
6762 }
6863 validateIp (host , connectorPrivateIpEnabled );
@@ -71,7 +66,7 @@ public static void validate(String protocol, String host, int port, AtomicBoolea
7166 private static void validateIp (String hostName , AtomicBoolean connectorPrivateIpEnabled ) throws UnknownHostException {
7267 InetAddress [] addresses = InetAddress .getAllByName (hostName );
7368 if ((connectorPrivateIpEnabled == null || !connectorPrivateIpEnabled .get ()) && hasPrivateIpAddress (addresses )) {
74- log .error ("Remote inference host name has private ip address: " + hostName );
69+ log .error ("Remote inference host name has private ip address: {}" , hostName );
7570 throw new IllegalArgumentException ("Remote inference host name has private ip address: " + hostName );
7671 }
7772 }
@@ -83,35 +78,23 @@ private static boolean hasPrivateIpAddress(InetAddress[] ipAddress) {
8378 if (bytes .length != 4 ) {
8479 return true ;
8580 } else {
86- int firstOctets = bytes [0 ] & 0xff ;
87- int firstInOctal = parseWithOctal (String .valueOf (firstOctets ));
88- int firstInHex = Integer .parseInt (String .valueOf (firstOctets ), 16 );
89- if (firstInOctal == 127 || firstInHex == 127 ) {
90- return bytes [1 ] == 0 && bytes [2 ] == 0 && bytes [3 ] == 1 ;
91- } else if (firstInOctal == 10 || firstInHex == 10 ) {
81+ if (isPrivateIPv4 (bytes )) {
9282 return true ;
93- } else if (firstInOctal == 172 || firstInHex == 172 ) {
94- int secondOctets = bytes [1 ] & 0xff ;
95- int secondInOctal = parseWithOctal (String .valueOf (secondOctets ));
96- int secondInHex = Integer .parseInt (String .valueOf (secondOctets ), 16 );
97- return (secondInOctal >= 16 && secondInOctal <= 32 ) || (secondInHex >= 16 && secondInHex <= 32 );
98- } else if (firstInOctal == 192 || firstInHex == 192 ) {
99- int secondOctets = bytes [1 ] & 0xff ;
100- int secondInOctal = parseWithOctal (String .valueOf (secondOctets ));
101- int secondInHex = Integer .parseInt (String .valueOf (secondOctets ), 16 );
102- return secondInOctal == 168 || secondInHex == 168 ;
10383 }
10484 }
10585 }
10686 }
10787 return Arrays .stream (ipAddress ).anyMatch (x -> x .isSiteLocalAddress () || x .isLoopbackAddress () || x .isAnyLocalAddress ());
10888 }
10989
110- private static int parseWithOctal (String input ) {
111- try {
112- return Integer .parseInt (input , 8 );
113- } catch (NumberFormatException e ) {
114- return Integer .parseInt (input );
115- }
90+ private static boolean isPrivateIPv4 (byte [] bytes ) {
91+ int first = bytes [0 ] & 0xff ;
92+ int second = bytes [1 ] & 0xff ;
93+
94+ // 127.0.0.1, 10.x.x.x, 172.16-31.x.x, 192.168.x.x, 169.254.x.x
95+ return (first == 10 )
96+ || (first == 172 && second >= 16 && second <= 31 )
97+ || (first == 192 && second == 168 )
98+ || (first == 169 && second == 254 );
11699 }
117100}
0 commit comments