Skip to content

Commit 9d2556a

Browse files
authored
[cherry-pick] support connectType if use oss-bucket (#2988)
Signed-off-by: lentitude2tk <[email protected]>
1 parent 8b09a7c commit 9d2556a

File tree

6 files changed

+107
-6
lines changed

6 files changed

+107
-6
lines changed
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
1+
from pymilvus.bulk_writer.constants import ConnectType
12
from pymilvus.bulk_writer.stage_file_manager import StageFileManager
23

34
if __name__ == "__main__":
45
stage_file_manager = StageFileManager(
56
cloud_endpoint='https://api.cloud.zilliz.com',
67
api_key='_api_key_for_cluster_org_',
78
stage_name='_stage_name_for_project_',
9+
connect_type=ConnectType.AUTO,
810
)
911
result = stage_file_manager.upload_file_to_stage("/Users/zilliz/data/", "data/")
1012
print(f"\nuploadFileToStage results: {result}")

pymilvus/bulk_writer/constants.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
# or implied. See the License for the specific language governing permissions and limitations under
1111
# the License.
1212

13-
from enum import IntEnum
13+
from enum import Enum, IntEnum
1414

1515
import numpy as np
1616

@@ -88,3 +88,9 @@ class BulkFileType(IntEnum):
8888
JSON_RB = 2 # deprecated
8989
PARQUET = 3
9090
CSV = 4
91+
92+
93+
class ConnectType(Enum):
94+
AUTO = "AUTO"
95+
INTERNAL = "INTERNAL"
96+
PUBLIC = "PUBLIC"
Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
import http
2+
import logging
3+
4+
from pymilvus.bulk_writer.constants import ConnectType
5+
6+
logger = logging.getLogger("EndpointResolver")
7+
logging.basicConfig(level=logging.INFO)
8+
9+
10+
class EndpointResolver:
11+
@staticmethod
12+
def resolve_endpoint(
13+
default_endpoint: str, cloud: str, region: str, connect_type: ConnectType
14+
) -> str:
15+
logger.info(
16+
"Start resolving endpoint, cloud=%s, region=%s, connectType=%s",
17+
cloud,
18+
region,
19+
connect_type,
20+
)
21+
if cloud == "ali":
22+
default_endpoint = EndpointResolver._resolve_oss_endpoint(region, connect_type)
23+
logger.info("Resolved endpoint: %s, reachable check passed", default_endpoint)
24+
return default_endpoint
25+
26+
@staticmethod
27+
def _resolve_oss_endpoint(region: str, connect_type: ConnectType) -> str:
28+
internal_endpoint = f"oss-{region}-internal.aliyuncs.com"
29+
public_endpoint = f"oss-{region}.aliyuncs.com"
30+
31+
if connect_type == ConnectType.INTERNAL:
32+
logger.info("Forced INTERNAL endpoint selected: %s", internal_endpoint)
33+
EndpointResolver._check_endpoint_reachable(internal_endpoint, True)
34+
return internal_endpoint
35+
if connect_type == ConnectType.PUBLIC:
36+
logger.info("Forced PUBLIC endpoint selected: %s", public_endpoint)
37+
EndpointResolver._check_endpoint_reachable(public_endpoint, True)
38+
return public_endpoint
39+
if EndpointResolver._check_endpoint_reachable(internal_endpoint, False):
40+
logger.info("AUTO mode: internal endpoint reachable, using %s", internal_endpoint)
41+
return internal_endpoint
42+
logger.warning(
43+
"AUTO mode: internal endpoint not reachable, fallback to public endpoint %s",
44+
public_endpoint,
45+
)
46+
EndpointResolver._check_endpoint_reachable(public_endpoint, True)
47+
return public_endpoint
48+
49+
@staticmethod
50+
def _check_endpoint_reachable(endpoint: str, raise_error: bool) -> bool:
51+
try:
52+
conn = http.client.HTTPSConnection(endpoint, timeout=5)
53+
conn.request("HEAD", "/")
54+
resp = conn.getresponse()
55+
code = resp.status
56+
logger.debug("Checked endpoint %s, response code=%s", endpoint, code)
57+
except Exception as e:
58+
if raise_error:
59+
logger.exception("Endpoint %s not reachable, throwing exception", endpoint)
60+
raise RuntimeError(str(e)) from e
61+
logger.warning("Endpoint %s not reachable, will fallback if needed", endpoint)
62+
return False
63+
else:
64+
return 200 <= code < 400

pymilvus/bulk_writer/stage_bulk_writer.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from pymilvus.bulk_writer.stage_file_manager import StageFileManager
77
from pymilvus.orm.schema import CollectionSchema
88

9-
from .constants import MB, BulkFileType
9+
from .constants import MB, BulkFileType, ConnectType
1010
from .local_bulk_writer import LocalBulkWriter
1111

1212
logger = logging.getLogger(__name__)
@@ -35,7 +35,10 @@ def __init__(
3535
self._remote_files: List[List[str]] = []
3636
self._stage_name = stage_name
3737
self._stage_file_manager = StageFileManager(
38-
cloud_endpoint=cloud_endpoint, api_key=api_key, stage_name=stage_name
38+
cloud_endpoint=cloud_endpoint,
39+
api_key=api_key,
40+
stage_name=stage_name,
41+
connect_type=ConnectType.AUTO,
3942
)
4043

4144
logger.info(f"Remote buffer writer initialized, target path: {self._remote_path}")

pymilvus/bulk_writer/stage_file_manager.py

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@
99
from minio import Minio
1010
from minio.error import S3Error
1111

12+
from pymilvus.bulk_writer.constants import ConnectType
13+
from pymilvus.bulk_writer.endpoint_resolver import EndpointResolver
1214
from pymilvus.bulk_writer.file_utils import FileUtils
1315
from pymilvus.bulk_writer.stage_restful import apply_stage
1416

@@ -17,7 +19,13 @@
1719

1820

1921
class StageFileManager:
20-
def __init__(self, cloud_endpoint: str, api_key: str, stage_name: str):
22+
def __init__(
23+
self,
24+
cloud_endpoint: str,
25+
api_key: str,
26+
stage_name: str,
27+
connect_type: ConnectType = ConnectType.AUTO,
28+
):
2129
"""
2230
private preview feature. Please submit a request and contact us if you need it.
2331
@@ -27,10 +35,16 @@ def __init__(self, cloud_endpoint: str, api_key: str, stage_name: str):
2735
- For regions in China: https://api.cloud.zilliz.com.cn
2836
api_key (str): The API key associated with your organization
2937
stage_name (str): The name of the Stage.
38+
connect_type: Current value is mainly for Aliyun OSS buckets, default is Auto.
39+
- Default case, if the OSS bucket is reachable via the internal endpoint,
40+
the internal endpoint will be used
41+
- otherwise, the public endpoint will be used.
42+
- You can also force the use of either the internal or public endpoint.
3043
"""
3144
self.cloud_endpoint = cloud_endpoint
3245
self.api_key = api_key
3346
self.stage_name = stage_name
47+
self.connect_type = connect_type
3448
self.local_file_paths = []
3549
self.total_bytes = 0
3650
self.stage_info = {}
@@ -51,8 +65,15 @@ def _refresh_stage_and_client(self, path: str):
5165

5266
creds = self.stage_info["credentials"]
5367
http_client = urllib3.PoolManager(maxsize=100)
68+
69+
endpoint = EndpointResolver.resolve_endpoint(
70+
self.stage_info["endpoint"],
71+
self.stage_info["cloud"],
72+
self.stage_info["region"],
73+
self.connect_type,
74+
)
5475
self._client = Minio(
55-
endpoint=self.stage_info["endpoint"],
76+
endpoint=endpoint,
5677
access_key=creds["tmpAK"],
5778
secret_key=creds["tmpSK"],
5879
session_token=creds["sessionToken"],

tests/test_bulk_writer_stage.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
import pytest
77
import requests
8-
from pymilvus.bulk_writer.constants import BulkFileType
8+
from pymilvus.bulk_writer.constants import BulkFileType, ConnectType
99
from pymilvus.bulk_writer.stage_bulk_writer import StageBulkWriter
1010
from pymilvus.bulk_writer.stage_file_manager import StageFileManager
1111
from pymilvus.bulk_writer.stage_manager import StageManager
@@ -127,6 +127,7 @@ def test_apply_stage_success(
127127
"endpoint": "s3.amazonaws.com",
128128
"bucketName": "test-bucket",
129129
"region": "us-west-2",
130+
"cloud": "aws",
130131
"condition": {"maxContentLength": 1073741824},
131132
"credentials": {
132133
"tmpAK": "test_access_key",
@@ -238,6 +239,7 @@ def stage_file_manager(self) -> StageFileManager:
238239
cloud_endpoint="https://api.cloud.zilliz.com",
239240
api_key="test_api_key",
240241
stage_name="test_stage",
242+
connect_type=ConnectType.AUTO,
241243
)
242244

243245
@pytest.fixture
@@ -249,6 +251,7 @@ def mock_stage_info(self) -> Dict[str, Any]:
249251
"endpoint": "s3.amazonaws.com",
250252
"bucketName": "test-bucket",
251253
"region": "us-west-2",
254+
"cloud": "aws",
252255
"condition": {"maxContentLength": 1073741824},
253256
"credentials": {
254257
"tmpAK": "test_access_key",
@@ -563,6 +566,7 @@ def mock_server_responses(self) -> Dict[str, Any]:
563566
"endpoint": "s3.amazonaws.com",
564567
"bucketName": "test-bucket",
565568
"region": "us-west-2",
569+
"cloud": "aws",
566570
"condition": {"maxContentLength": 1073741824},
567571
"credentials": {
568572
"tmpAK": "test_access_key",
@@ -614,6 +618,7 @@ def test_full_stage_workflow(
614618
cloud_endpoint="https://api.cloud.zilliz.com",
615619
api_key="test_api_key",
616620
stage_name="test_stage",
621+
connect_type=ConnectType.AUTO,
617622
)
618623

619624
# Verify stage info can be refreshed

0 commit comments

Comments
 (0)