2727from mock import call , patch , Mock , MagicMock
2828
2929import sagemaker
30+ from sagemaker .session_settings import SessionSettings
3031
3132BUCKET_WITHOUT_WRITING_PERMISSION = "s3://bucket-without-writing-permission"
3233
@@ -390,6 +391,13 @@ def test_repack_model_without_source_dir(tmp, fake_s3):
390391 "/code/inference.py" ,
391392 }
392393
394+ extra_args = {"ServerSideEncryption" : "aws:kms" }
395+ object_mock = fake_s3 .object_mock
396+ _ , _ , kwargs = object_mock .mock_calls [0 ]
397+
398+ assert "ExtraArgs" in kwargs
399+ assert kwargs ["ExtraArgs" ] == extra_args
400+
393401
394402def test_repack_model_with_entry_point_without_path_without_source_dir (tmp , fake_s3 ):
395403
@@ -415,12 +423,20 @@ def test_repack_model_with_entry_point_without_path_without_source_dir(tmp, fake
415423 "s3://fake/location" ,
416424 "s3://destination-bucket/model.tar.gz" ,
417425 fake_s3 .sagemaker_session ,
426+ kms_key = "kms_key" ,
418427 )
419428 finally :
420429 os .chdir (cwd )
421430
422431 assert list_tar_files (fake_s3 .fake_upload_path , tmp ) == {"/code/inference.py" , "/model" }
423432
433+ extra_args = {"ServerSideEncryption" : "aws:kms" , "SSEKMSKeyId" : "kms_key" }
434+ object_mock = fake_s3 .object_mock
435+ _ , _ , kwargs = object_mock .mock_calls [0 ]
436+
437+ assert "ExtraArgs" in kwargs
438+ assert kwargs ["ExtraArgs" ] == extra_args
439+
424440
425441def test_repack_model_from_s3_to_s3 (tmp , fake_s3 ):
426442
@@ -434,6 +450,7 @@ def test_repack_model_from_s3_to_s3(tmp, fake_s3):
434450 )
435451
436452 fake_s3 .tar_and_upload ("model-dir" , "s3://fake/location" )
453+ fake_s3 .sagemaker_session .settings = SessionSettings (encrypt_repacked_artifacts = False )
437454
438455 sagemaker .utils .repack_model (
439456 "inference.py" ,
@@ -450,6 +467,11 @@ def test_repack_model_from_s3_to_s3(tmp, fake_s3):
450467 "/model" ,
451468 }
452469
470+ object_mock = fake_s3 .object_mock
471+ _ , _ , kwargs = object_mock .mock_calls [0 ]
472+ assert "ExtraArgs" in kwargs
473+ assert kwargs ["ExtraArgs" ] is None
474+
453475
454476def test_repack_model_from_file_to_file (tmp ):
455477 create_file_tree (tmp , ["model" , "dependencies/a" , "source-dir/inference.py" ])
@@ -581,11 +603,15 @@ def __init__(self, tmp):
581603 self .sagemaker_session = MagicMock ()
582604 self .location_map = {}
583605 self .current_bucket = None
606+ self .object_mock = MagicMock ()
584607
585608 self .sagemaker_session .boto_session .resource ().Bucket ().download_file .side_effect = (
586609 self .download_file
587610 )
588611 self .sagemaker_session .boto_session .resource ().Bucket .side_effect = self .bucket
612+ self .sagemaker_session .boto_session .resource ().Object = Mock (
613+ name = "boto_session" , region_name = "us-west-2"
614+ )
589615 self .fake_upload_path = self .mock_s3_upload ()
590616
591617 def bucket (self , name ):
@@ -606,6 +632,7 @@ def tar_and_upload(self, path, fake_location):
606632
607633 def mock_s3_upload (self ):
608634 dst = os .path .join (self .tmp , "dst" )
635+ object_mock = self .object_mock
609636
610637 class MockS3Object (object ):
611638 def __init__ (self , bucket , key ):
@@ -616,6 +643,7 @@ def upload_file(self, target, **kwargs):
616643 if self .bucket in BUCKET_WITHOUT_WRITING_PERMISSION :
617644 raise exceptions .S3UploadFailedError ()
618645 shutil .copy2 (target , dst )
646+ object_mock .upload_file (target , ** kwargs )
619647
620648 self .sagemaker_session .boto_session .resource ().Object = MockS3Object
621649 return dst
0 commit comments