@@ -136,33 +136,19 @@ def test_jumpstart_cache_get_header():
136136
137137 cache = JumpStartModelsCache (s3_bucket_name = "some_bucket" )
138138
139- assert (
140- JumpStartModelHeader (
141- {
142- "model_id" : "tensorflow-ic-imagenet-inception-v3-classification-4" ,
143- "version" : "2.0.0" ,
144- "min_version" : "2.49.0" ,
145- "spec_key" : "community_models_specs/tensorflow-ic"
146- "-imagenet-inception-v3-classification-4/specs_v2.0.0.json" ,
147- }
148- )
149- == cache .get_header (model_id = "tensorflow-ic-imagenet-inception-v3-classification-4" )
139+ assert JumpStartModelHeader (
140+ {
141+ "model_id" : "tensorflow-ic-imagenet-inception-v3-classification-4" ,
142+ "version" : "2.0.0" ,
143+ "min_version" : "2.49.0" ,
144+ "spec_key" : "community_models_specs/tensorflow-ic"
145+ "-imagenet-inception-v3-classification-4/specs_v2.0.0.json" ,
146+ }
147+ ) == cache .get_header (
148+ model_id = "tensorflow-ic-imagenet-inception-v3-classification-4" , semantic_version_str = "*"
150149 )
151150
152151 # See if we can make the same query 2 times consecutively
153- assert (
154- JumpStartModelHeader (
155- {
156- "model_id" : "tensorflow-ic-imagenet-inception-v3-classification-4" ,
157- "version" : "2.0.0" ,
158- "min_version" : "2.49.0" ,
159- "spec_key" : "community_models_specs/tensorflow-ic"
160- "-imagenet-inception-v3-classification-4/specs_v2.0.0.json" ,
161- }
162- )
163- == cache .get_header (model_id = "tensorflow-ic-imagenet-inception-v3-classification-4" )
164- )
165-
166152 assert JumpStartModelHeader (
167153 {
168154 "model_id" : "tensorflow-ic-imagenet-inception-v3-classification-4" ,
@@ -278,6 +264,7 @@ def test_jumpstart_cache_get_header():
278264 with pytest .raises (KeyError ):
279265 cache .get_header (
280266 model_id = "tensorflow-ic-imagenet-inception-v3-classification-4-bak" ,
267+ semantic_version_str = "*" ,
281268 )
282269
283270
@@ -340,21 +327,30 @@ def test_jumpstart_cache_handles_boto3_client_errors():
340327 stubbed_s3_client .add_client_error ("get_object" , http_status_code = 404 )
341328 stubbed_s3_client .activate ()
342329 with pytest .raises (botocore .exceptions .ClientError ):
343- cache .get_header (model_id = "tensorflow-ic-imagenet-inception-v3-classification-4" )
330+ cache .get_header (
331+ model_id = "tensorflow-ic-imagenet-inception-v3-classification-4" ,
332+ semantic_version_str = "*" ,
333+ )
344334
345335 cache = JumpStartModelsCache (s3_bucket_name = "some_bucket" )
346336 stubbed_s3_client = Stubber (cache ._s3_client )
347337 stubbed_s3_client .add_client_error ("get_object" , service_error_code = "AccessDenied" )
348338 stubbed_s3_client .activate ()
349339 with pytest .raises (botocore .exceptions .ClientError ):
350- cache .get_header (model_id = "tensorflow-ic-imagenet-inception-v3-classification-4" )
340+ cache .get_header (
341+ model_id = "tensorflow-ic-imagenet-inception-v3-classification-4" ,
342+ semantic_version_str = "*" ,
343+ )
351344
352345 cache = JumpStartModelsCache (s3_bucket_name = "some_bucket" )
353346 stubbed_s3_client = Stubber (cache ._s3_client )
354347 stubbed_s3_client .add_client_error ("get_object" , service_error_code = "EndpointConnectionError" )
355348 stubbed_s3_client .activate ()
356349 with pytest .raises (botocore .exceptions .ClientError ):
357- cache .get_header (model_id = "tensorflow-ic-imagenet-inception-v3-classification-4" )
350+ cache .get_header (
351+ model_id = "tensorflow-ic-imagenet-inception-v3-classification-4" ,
352+ semantic_version_str = "*" ,
353+ )
358354
359355 # Testing head_object:
360356 mock_now = datetime .datetime .fromtimestamp (1636730651.079551 )
@@ -388,13 +384,18 @@ def test_jumpstart_cache_handles_boto3_client_errors():
388384
389385 stubbed_s3_client1 .add_response ("get_object" , copy .deepcopy (get_object_mocked_response ))
390386 stubbed_s3_client1 .activate ()
391- cache1 .get_header (model_id = "pytorch-ic-imagenet-inception-v3-classification-4" )
387+ cache1 .get_header (
388+ model_id = "pytorch-ic-imagenet-inception-v3-classification-4" , semantic_version_str = "*"
389+ )
392390
393391 mock_datetime .now .return_value += datetime .timedelta (weeks = 1 )
394392
395393 stubbed_s3_client1 .add_client_error ("head_object" , http_status_code = 404 )
396394 with pytest .raises (botocore .exceptions .ClientError ):
397- cache1 .get_header (model_id = "pytorch-ic-imagenet-inception-v3-classification-4" )
395+ cache1 .get_header (
396+ model_id = "pytorch-ic-imagenet-inception-v3-classification-4" ,
397+ semantic_version_str = "*" ,
398+ )
398399
399400 cache2 = JumpStartModelsCache (
400401 s3_bucket_name = "some_bucket" , s3_cache_expiration_horizon = datetime .timedelta (hours = 1 )
@@ -403,13 +404,18 @@ def test_jumpstart_cache_handles_boto3_client_errors():
403404
404405 stubbed_s3_client2 .add_response ("get_object" , copy .deepcopy (get_object_mocked_response ))
405406 stubbed_s3_client2 .activate ()
406- cache2 .get_header (model_id = "pytorch-ic-imagenet-inception-v3-classification-4" )
407+ cache2 .get_header (
408+ model_id = "pytorch-ic-imagenet-inception-v3-classification-4" , semantic_version_str = "*"
409+ )
407410
408411 mock_datetime .now .return_value += datetime .timedelta (weeks = 1 )
409412
410413 stubbed_s3_client2 .add_client_error ("head_object" , service_error_code = "AccessDenied" )
411414 with pytest .raises (botocore .exceptions .ClientError ):
412- cache2 .get_header (model_id = "pytorch-ic-imagenet-inception-v3-classification-4" )
415+ cache2 .get_header (
416+ model_id = "pytorch-ic-imagenet-inception-v3-classification-4" ,
417+ semantic_version_str = "*" ,
418+ )
413419
414420 cache3 = JumpStartModelsCache (
415421 s3_bucket_name = "some_bucket" , s3_cache_expiration_horizon = datetime .timedelta (hours = 1 )
@@ -418,15 +424,20 @@ def test_jumpstart_cache_handles_boto3_client_errors():
418424
419425 stubbed_s3_client3 .add_response ("get_object" , copy .deepcopy (get_object_mocked_response ))
420426 stubbed_s3_client3 .activate ()
421- cache3 .get_header (model_id = "pytorch-ic-imagenet-inception-v3-classification-4" )
427+ cache3 .get_header (
428+ model_id = "pytorch-ic-imagenet-inception-v3-classification-4" , semantic_version_str = "*"
429+ )
422430
423431 mock_datetime .now .return_value += datetime .timedelta (weeks = 1 )
424432
425433 stubbed_s3_client3 .add_client_error (
426434 "head_object" , service_error_code = "EndpointConnectionError"
427435 )
428436 with pytest .raises (botocore .exceptions .ClientError ):
429- cache3 .get_header (model_id = "pytorch-ic-imagenet-inception-v3-classification-4" )
437+ cache3 .get_header (
438+ model_id = "pytorch-ic-imagenet-inception-v3-classification-4" ,
439+ semantic_version_str = "*" ,
440+ )
430441
431442
432443def test_jumpstart_cache_accepts_input_parameters ():
@@ -497,7 +508,9 @@ def test_jumpstart_cache_evaluates_md5_hash(mock_boto3_client):
497508 }
498509 mock_boto3_client .return_value .head_object .return_value = {"ETag" : "hash1" }
499510
500- cache .get_header (model_id = "pytorch-ic-imagenet-inception-v3-classification-4" )
511+ cache .get_header (
512+ model_id = "pytorch-ic-imagenet-inception-v3-classification-4" , semantic_version_str = "*"
513+ )
501514
502515 # first time accessing cache should just involve get_object
503516 mock_boto3_client .return_value .get_object .assert_called_with (
@@ -520,7 +533,9 @@ def test_jumpstart_cache_evaluates_md5_hash(mock_boto3_client):
520533 # invalidate cache
521534 mock_datetime .now .return_value += datetime .timedelta (hours = 2 )
522535
523- cache .get_header (model_id = "pytorch-ic-imagenet-inception-v3-classification-4" )
536+ cache .get_header (
537+ model_id = "pytorch-ic-imagenet-inception-v3-classification-4" , semantic_version_str = "*"
538+ )
524539
525540 mock_boto3_client .return_value .head_object .assert_called_with (
526541 Bucket = bucket_name , Key = JUMPSTART_DEFAULT_MANIFEST_FILE_S3_KEY
@@ -542,7 +557,9 @@ def test_jumpstart_cache_evaluates_md5_hash(mock_boto3_client):
542557 # invalidate cache
543558 mock_datetime .now .return_value += datetime .timedelta (hours = 2 )
544559
545- cache .get_header (model_id = "pytorch-ic-imagenet-inception-v3-classification-4" )
560+ cache .get_header (
561+ model_id = "pytorch-ic-imagenet-inception-v3-classification-4" , semantic_version_str = "*"
562+ )
546563
547564 mock_boto3_client .return_value .get_object .assert_called_with (
548565 Bucket = bucket_name , Key = JUMPSTART_DEFAULT_MANIFEST_FILE_S3_KEY
@@ -581,7 +598,9 @@ def test_jumpstart_cache_makes_correct_s3_calls(mock_boto3_client):
581598 cache = JumpStartModelsCache (
582599 s3_bucket_name = bucket_name , s3_client_config = client_config , region = "my_region"
583600 )
584- cache .get_header (model_id = "pytorch-ic-imagenet-inception-v3-classification-4" )
601+ cache .get_header (
602+ model_id = "pytorch-ic-imagenet-inception-v3-classification-4" , semantic_version_str = "*"
603+ )
585604
586605 mock_boto3_client .return_value .get_object .assert_called_with (
587606 Bucket = bucket_name , Key = JUMPSTART_DEFAULT_MANIFEST_FILE_S3_KEY
@@ -601,7 +620,9 @@ def test_jumpstart_cache_makes_correct_s3_calls(mock_boto3_client):
601620 ),
602621 "ETag" : "etag" ,
603622 }
604- cache .get_specs (model_id = "pytorch-ic-imagenet-inception-v3-classification-4" )
623+ cache .get_specs (
624+ model_id = "pytorch-ic-imagenet-inception-v3-classification-4" , semantic_version_str = "*"
625+ )
605626
606627 mock_boto3_client .return_value .get_object .assert_called_with (
607628 Bucket = bucket_name ,
@@ -633,7 +654,7 @@ def test_jumpstart_cache_handles_bad_semantic_version_manifest_key_cache():
633654 "imagenet-inception-v3-classification-4/specs_v1.0.0.json" ,
634655 }
635656 ) == cache .get_header (
636- model_id = "tensorflow-ic-imagenet-inception-v3-classification-4" ,
657+ model_id = "tensorflow-ic-imagenet-inception-v3-classification-4" , semantic_version_str = "*"
637658 )
638659 cache .clear .assert_called_once ()
639660 cache .clear .reset_mock ()
@@ -649,6 +670,7 @@ def test_jumpstart_cache_handles_bad_semantic_version_manifest_key_cache():
649670 with pytest .raises (KeyError ):
650671 cache .get_header (
651672 model_id = "tensorflow-ic-imagenet-inception-v3-classification-4" ,
673+ semantic_version_str = "*" ,
652674 )
653675 cache .clear .assert_called_once ()
654676
@@ -683,9 +705,7 @@ def test_jumpstart_cache_get_specs():
683705 )
684706
685707 with pytest .raises (KeyError ):
686- cache .get_specs (
687- model_id = model_id + "bak" ,
688- )
708+ cache .get_specs (model_id = model_id + "bak" , semantic_version_str = "*" )
689709
690710 with pytest .raises (KeyError ):
691711 cache .get_specs (model_id = model_id , semantic_version_str = "9.*" )
0 commit comments