| 
 | 1 | +import http  | 
 | 2 | +import logging  | 
 | 3 | + | 
 | 4 | +from pymilvus.bulk_writer.constants import ConnectType  | 
 | 5 | + | 
 | 6 | +logger = logging.getLogger("EndpointResolver")  | 
 | 7 | +logging.basicConfig(level=logging.INFO)  | 
 | 8 | + | 
 | 9 | + | 
 | 10 | +class EndpointResolver:  | 
 | 11 | +    @staticmethod  | 
 | 12 | +    def resolve_endpoint(  | 
 | 13 | +        default_endpoint: str, cloud: str, region: str, connect_type: ConnectType  | 
 | 14 | +    ) -> str:  | 
 | 15 | +        logger.info(  | 
 | 16 | +            "Start resolving endpoint, cloud=%s, region=%s, connectType=%s",  | 
 | 17 | +            cloud,  | 
 | 18 | +            region,  | 
 | 19 | +            connect_type,  | 
 | 20 | +        )  | 
 | 21 | +        if cloud == "ali":  | 
 | 22 | +            default_endpoint = EndpointResolver._resolve_oss_endpoint(region, connect_type)  | 
 | 23 | +        logger.info("Resolved endpoint: %s, reachable check passed", default_endpoint)  | 
 | 24 | +        return default_endpoint  | 
 | 25 | + | 
 | 26 | +    @staticmethod  | 
 | 27 | +    def _resolve_oss_endpoint(region: str, connect_type: ConnectType) -> str:  | 
 | 28 | +        internal_endpoint = f"oss-{region}-internal.aliyuncs.com"  | 
 | 29 | +        public_endpoint = f"oss-{region}.aliyuncs.com"  | 
 | 30 | + | 
 | 31 | +        if connect_type == ConnectType.INTERNAL:  | 
 | 32 | +            logger.info("Forced INTERNAL endpoint selected: %s", internal_endpoint)  | 
 | 33 | +            EndpointResolver._check_endpoint_reachable(internal_endpoint, True)  | 
 | 34 | +            return internal_endpoint  | 
 | 35 | +        if connect_type == ConnectType.PUBLIC:  | 
 | 36 | +            logger.info("Forced PUBLIC endpoint selected: %s", public_endpoint)  | 
 | 37 | +            EndpointResolver._check_endpoint_reachable(public_endpoint, True)  | 
 | 38 | +            return public_endpoint  | 
 | 39 | +        if EndpointResolver._check_endpoint_reachable(internal_endpoint, False):  | 
 | 40 | +            logger.info("AUTO mode: internal endpoint reachable, using %s", internal_endpoint)  | 
 | 41 | +            return internal_endpoint  | 
 | 42 | +        logger.warning(  | 
 | 43 | +            "AUTO mode: internal endpoint not reachable, fallback to public endpoint %s",  | 
 | 44 | +            public_endpoint,  | 
 | 45 | +        )  | 
 | 46 | +        EndpointResolver._check_endpoint_reachable(public_endpoint, True)  | 
 | 47 | +        return public_endpoint  | 
 | 48 | + | 
 | 49 | +    @staticmethod  | 
 | 50 | +    def _check_endpoint_reachable(endpoint: str, raise_error: bool) -> bool:  | 
 | 51 | +        try:  | 
 | 52 | +            conn = http.client.HTTPSConnection(endpoint, timeout=5)  | 
 | 53 | +            conn.request("HEAD", "/")  | 
 | 54 | +            resp = conn.getresponse()  | 
 | 55 | +            code = resp.status  | 
 | 56 | +            logger.debug("Checked endpoint %s, response code=%s", endpoint, code)  | 
 | 57 | +        except Exception as e:  | 
 | 58 | +            if raise_error:  | 
 | 59 | +                logger.exception("Endpoint %s not reachable, throwing exception", endpoint)  | 
 | 60 | +                raise RuntimeError(str(e)) from e  | 
 | 61 | +            logger.warning("Endpoint %s not reachable, will fallback if needed", endpoint)  | 
 | 62 | +            return False  | 
 | 63 | +        else:  | 
 | 64 | +            return 200 <= code < 400  | 
0 commit comments