@@ -34,7 +34,7 @@ def get_sagemaker_version() -> str:
3434 return SageMakerSettings ._parsed_sagemaker_version
3535
3636
37- class JumpStartModelsCache (object ):
37+ class JumpStartModelsAccessor (object ):
3838 """Static class for storing the JumpStart models cache."""
3939
4040 _cache : Optional [cache .JumpStartModelsCache ] = None
@@ -67,15 +67,17 @@ def _validate_and_mutate_region_cache_kwargs(
6767
6868 @staticmethod
6969 def _set_cache_and_region (region : str , cache_kwargs : dict ) -> None :
70- """Sets ``JumpStartModelsCache ._cache`` and ``JumpStartModelsCache ._curr_region``.
70+ """Sets ``JumpStartModelsAccessor ._cache`` and ``JumpStartModelsAccessor ._curr_region``.
7171
7272 Args:
7373 region (str): region for which to retrieve header/spec.
7474 cache_kwargs (dict): kwargs to pass to ``JumpStartModelsCache``.
7575 """
76- if JumpStartModelsCache ._cache is None or region != JumpStartModelsCache ._curr_region :
77- JumpStartModelsCache ._cache = cache .JumpStartModelsCache (region = region , ** cache_kwargs )
78- JumpStartModelsCache ._curr_region = region
76+ if JumpStartModelsAccessor ._cache is None or region != JumpStartModelsAccessor ._curr_region :
77+ JumpStartModelsAccessor ._cache = cache .JumpStartModelsCache (
78+ region = region , ** cache_kwargs
79+ )
80+ JumpStartModelsAccessor ._curr_region = region
7981
8082 @staticmethod
8183 def get_model_header (region : str , model_id : str , version : str ) -> JumpStartModelHeader :
@@ -86,12 +88,12 @@ def get_model_header(region: str, model_id: str, version: str) -> JumpStartModel
8688 model_id (str): model id to retrieve.
8789 version (str): semantic version to retrieve for the model id.
8890 """
89- cache_kwargs = JumpStartModelsCache ._validate_and_mutate_region_cache_kwargs (
90- JumpStartModelsCache ._cache_kwargs , region
91+ cache_kwargs = JumpStartModelsAccessor ._validate_and_mutate_region_cache_kwargs (
92+ JumpStartModelsAccessor ._cache_kwargs , region
9193 )
92- JumpStartModelsCache ._set_cache_and_region (region , cache_kwargs )
93- assert JumpStartModelsCache ._cache is not None
94- return JumpStartModelsCache ._cache .get_header (model_id , version )
94+ JumpStartModelsAccessor ._set_cache_and_region (region , cache_kwargs )
95+ assert JumpStartModelsAccessor ._cache is not None
96+ return JumpStartModelsAccessor ._cache .get_header (model_id , version )
9597
9698 @staticmethod
9799 def get_model_specs (region : str , model_id : str , version : str ) -> JumpStartModelSpecs :
@@ -102,12 +104,12 @@ def get_model_specs(region: str, model_id: str, version: str) -> JumpStartModelS
102104 model_id (str): model id to retrieve.
103105 version (str): semantic version to retrieve for the model id.
104106 """
105- cache_kwargs = JumpStartModelsCache ._validate_and_mutate_region_cache_kwargs (
106- JumpStartModelsCache ._cache_kwargs , region
107+ cache_kwargs = JumpStartModelsAccessor ._validate_and_mutate_region_cache_kwargs (
108+ JumpStartModelsAccessor ._cache_kwargs , region
107109 )
108- JumpStartModelsCache ._set_cache_and_region (region , cache_kwargs )
109- assert JumpStartModelsCache ._cache is not None
110- return JumpStartModelsCache ._cache .get_specs (model_id , version )
110+ JumpStartModelsAccessor ._set_cache_and_region (region , cache_kwargs )
111+ assert JumpStartModelsAccessor ._cache is not None
112+ return JumpStartModelsAccessor ._cache .get_specs (model_id , version )
111113
112114 @staticmethod
113115 def set_cache_kwargs (cache_kwargs : Dict [str , Any ], region : str = None ) -> None :
@@ -120,18 +122,18 @@ def set_cache_kwargs(cache_kwargs: Dict[str, Any], region: str = None) -> None:
120122 cache_kwargs (str): cache kwargs to validate.
121123 region (str): Optional. The region to validate along with the kwargs.
122124 """
123- cache_kwargs = JumpStartModelsCache ._validate_and_mutate_region_cache_kwargs (
125+ cache_kwargs = JumpStartModelsAccessor ._validate_and_mutate_region_cache_kwargs (
124126 cache_kwargs , region
125127 )
126- JumpStartModelsCache ._cache_kwargs = cache_kwargs
128+ JumpStartModelsAccessor ._cache_kwargs = cache_kwargs
127129 if region is None :
128- JumpStartModelsCache ._cache = cache .JumpStartModelsCache (
129- ** JumpStartModelsCache ._cache_kwargs
130+ JumpStartModelsAccessor ._cache = cache .JumpStartModelsCache (
131+ ** JumpStartModelsAccessor ._cache_kwargs
130132 )
131133 else :
132- JumpStartModelsCache ._curr_region = region
133- JumpStartModelsCache ._cache = cache .JumpStartModelsCache (
134- region = region , ** JumpStartModelsCache ._cache_kwargs
134+ JumpStartModelsAccessor ._curr_region = region
135+ JumpStartModelsAccessor ._cache = cache .JumpStartModelsCache (
136+ region = region , ** JumpStartModelsAccessor ._cache_kwargs
135137 )
136138
137139 @staticmethod
@@ -146,4 +148,4 @@ def reset_cache(cache_kwargs: Dict[str, Any] = None, region: Optional[str] = Non
146148 region (str): The region to validate along with the kwargs.
147149 """
148150 cache_kwargs_dict = {} if cache_kwargs is None else cache_kwargs
149- JumpStartModelsCache .set_cache_kwargs (cache_kwargs_dict , region )
151+ JumpStartModelsAccessor .set_cache_kwargs (cache_kwargs_dict , region )
0 commit comments