3535from sagemaker .jumpstart .filters import Constant , ModelFilter , Operator , evaluate_filter_expression
3636from sagemaker .jumpstart .types import JumpStartModelHeader , JumpStartModelSpecs
3737from sagemaker .jumpstart .utils import get_jumpstart_content_bucket , get_sagemaker_version
38+ from sagemaker .session import Session
3839
3940
4041def _compare_model_version_tuples ( # pylint: disable=too-many-return-statements
@@ -137,6 +138,7 @@ def extract_framework_task_model(model_id: str) -> Tuple[str, str, str]:
137138def list_jumpstart_tasks ( # pylint: disable=redefined-builtin
138139 filter : Union [Operator , str ] = Constant (BooleanValues .TRUE ),
139140 region : str = JUMPSTART_DEFAULT_REGION_NAME ,
141+ sagemaker_session : Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION ,
140142) -> List [str ]:
141143 """List tasks for JumpStart, and optionally apply filters to result.
142144
@@ -148,10 +150,14 @@ def list_jumpstart_tasks( # pylint: disable=redefined-builtin
148150 (Default: Constant(BooleanValues.TRUE)).
149151 region (str): Optional. The AWS region from which to retrieve JumpStart metadata regarding
150152 models. (Default: JUMPSTART_DEFAULT_REGION_NAME).
153+ sagemaker_session (sagemaker.session.Session): Optional. The SageMaker Session to
154+ use to perform the model search. (Default: DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
151155 """
152156
153157 tasks : Set [str ] = set ()
154- for model_id , _ in _generate_jumpstart_model_versions (filter = filter , region = region ):
158+ for model_id , _ in _generate_jumpstart_model_versions (
159+ filter = filter , region = region , sagemaker_session = sagemaker_session
160+ ):
155161 _ , task , _ = extract_framework_task_model (model_id )
156162 tasks .add (task )
157163 return sorted (list (tasks ))
@@ -160,6 +166,7 @@ def list_jumpstart_tasks( # pylint: disable=redefined-builtin
160166def list_jumpstart_frameworks ( # pylint: disable=redefined-builtin
161167 filter : Union [Operator , str ] = Constant (BooleanValues .TRUE ),
162168 region : str = JUMPSTART_DEFAULT_REGION_NAME ,
169+ sagemaker_session : Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION ,
163170) -> List [str ]:
164171 """List frameworks for JumpStart, and optionally apply filters to result.
165172
@@ -171,10 +178,14 @@ def list_jumpstart_frameworks( # pylint: disable=redefined-builtin
171178 (Default: Constant(BooleanValues.TRUE)).
172179 region (str): Optional. The AWS region from which to retrieve JumpStart metadata regarding
173180 models. (Default: JUMPSTART_DEFAULT_REGION_NAME).
181+ sagemaker_session (sagemaker.session.Session): Optional. The SageMaker Session
182+ to use to perform the model search. (Default: DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
174183 """
175184
176185 frameworks : Set [str ] = set ()
177- for model_id , _ in _generate_jumpstart_model_versions (filter = filter , region = region ):
186+ for model_id , _ in _generate_jumpstart_model_versions (
187+ filter = filter , region = region , sagemaker_session = sagemaker_session
188+ ):
178189 framework , _ , _ = extract_framework_task_model (model_id )
179190 frameworks .add (framework )
180191 return sorted (list (frameworks ))
@@ -183,6 +194,7 @@ def list_jumpstart_frameworks( # pylint: disable=redefined-builtin
183194def list_jumpstart_scripts ( # pylint: disable=redefined-builtin
184195 filter : Union [Operator , str ] = Constant (BooleanValues .TRUE ),
185196 region : str = JUMPSTART_DEFAULT_REGION_NAME ,
197+ sagemaker_session : Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION ,
186198) -> List [str ]:
187199 """List scripts for JumpStart, and optionally apply filters to result.
188200
@@ -194,19 +206,24 @@ def list_jumpstart_scripts( # pylint: disable=redefined-builtin
194206 (Default: Constant(BooleanValues.TRUE)).
195207 region (str): Optional. The AWS region from which to retrieve JumpStart metadata regarding
196208 models. (Default: JUMPSTART_DEFAULT_REGION_NAME).
209+ sagemaker_session (sagemaker.session.Session): Optional. The SageMaker Session to
210+ use to perform the model search. (Default: DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
197211 """
198212 if (isinstance (filter , Constant ) and filter .resolved_value == BooleanValues .TRUE ) or (
199213 isinstance (filter , str ) and filter .lower () == BooleanValues .TRUE .lower ()
200214 ):
201215 return sorted ([e .value for e in JumpStartScriptScope ])
202216
203217 scripts : Set [str ] = set ()
204- for model_id , version in _generate_jumpstart_model_versions (filter = filter , region = region ):
218+ for model_id , version in _generate_jumpstart_model_versions (
219+ filter = filter , region = region , sagemaker_session = sagemaker_session
220+ ):
205221 scripts .add (JumpStartScriptScope .INFERENCE )
206222 model_specs = accessors .JumpStartModelsAccessor .get_model_specs (
207223 region = region ,
208224 model_id = model_id ,
209225 version = version ,
226+ s3_client = sagemaker_session .s3_client ,
210227 )
211228 if model_specs .training_supported :
212229 scripts .add (JumpStartScriptScope .TRAINING )
@@ -222,6 +239,7 @@ def list_jumpstart_models( # pylint: disable=redefined-builtin
222239 list_incomplete_models : bool = False ,
223240 list_old_models : bool = False ,
224241 list_versions : bool = False ,
242+ sagemaker_session : Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION ,
225243) -> List [Union [Tuple [str ], Tuple [str , str ]]]:
226244 """List models for JumpStart, and optionally apply filters to result.
227245
@@ -241,11 +259,16 @@ def list_jumpstart_models( # pylint: disable=redefined-builtin
241259 versions should be included in the returned result. (Default: False).
242260 list_versions (bool): Optional. True if versions for models should be returned in addition
243261 to the id of the model. (Default: False).
262+ sagemaker_session (sagemaker.session.Session): Optional. The SageMaker Session to use
263+ to perform the model search. (Default: DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
244264 """
245265
246266 model_id_version_dict : Dict [str , List [str ]] = dict ()
247267 for model_id , version in _generate_jumpstart_model_versions (
248- filter = filter , region = region , list_incomplete_models = list_incomplete_models
268+ filter = filter ,
269+ region = region ,
270+ list_incomplete_models = list_incomplete_models ,
271+ sagemaker_session = sagemaker_session ,
249272 ):
250273 if model_id not in model_id_version_dict :
251274 model_id_version_dict [model_id ] = list ()
@@ -271,6 +294,7 @@ def _generate_jumpstart_model_versions( # pylint: disable=redefined-builtin
271294 filter : Union [Operator , str ] = Constant (BooleanValues .TRUE ),
272295 region : str = JUMPSTART_DEFAULT_REGION_NAME ,
273296 list_incomplete_models : bool = False ,
297+ sagemaker_session : Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION ,
274298) -> Generator :
275299 """Generate models for JumpStart, and optionally apply filters to result.
276300
@@ -286,9 +310,13 @@ def _generate_jumpstart_model_versions( # pylint: disable=redefined-builtin
286310 requested by the filter, and the filter cannot be resolved to a include/not include,
287311 whether the model should be included. By default, these models are omitted from
288312 results. (Default: False).
313+ sagemaker_session (sagemaker.session.Session): Optional. The SageMaker Session
314+ to use to perform the model search. (Default: DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
289315 """
290316
291- models_manifest_list = accessors .JumpStartModelsAccessor ._get_manifest (region = region )
317+ models_manifest_list = accessors .JumpStartModelsAccessor ._get_manifest (
318+ region = region , s3_client = sagemaker_session .s3_client
319+ )
292320
293321 if isinstance (filter , str ):
294322 filter = Identity (filter )
@@ -366,7 +394,7 @@ def evaluate_model(model_manifest: JumpStartModelHeader) -> Optional[Tuple[str,
366394
367395 model_specs = JumpStartModelSpecs (
368396 json .loads (
369- DEFAULT_JUMPSTART_SAGEMAKER_SESSION .read_s3_file (
397+ sagemaker_session .read_s3_file (
370398 get_jumpstart_content_bucket (region ), model_manifest .spec_key
371399 )
372400 )
@@ -418,7 +446,10 @@ def evaluate_model(model_manifest: JumpStartModelHeader) -> Optional[Tuple[str,
418446
419447
420448def get_model_url (
421- model_id : str , model_version : str , region : str = JUMPSTART_DEFAULT_REGION_NAME
449+ model_id : str ,
450+ model_version : str ,
451+ region : str = JUMPSTART_DEFAULT_REGION_NAME ,
452+ sagemaker_session : Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION ,
422453) -> str :
423454 """Retrieve web url describing pretrained model.
424455
@@ -427,9 +458,14 @@ def get_model_url(
427458 model_version (str): The model version for which to retrieve the url.
428459 region (str): Optional. The region from which to retrieve metadata.
429460 (Default: JUMPSTART_DEFAULT_REGION_NAME)
461+ sagemaker_session (sagemaker.session.Session): Optional. The SageMaker Session to use
462+ to retrieve the model url.
430463 """
431464
432465 model_specs = accessors .JumpStartModelsAccessor .get_model_specs (
433- region = region , model_id = model_id , version = model_version
466+ region = region ,
467+ model_id = model_id ,
468+ version = model_version ,
469+ s3_client = sagemaker_session .s3_client ,
434470 )
435471 return model_specs .url
0 commit comments