2727from mock import call , patch , Mock , MagicMock
2828
2929import sagemaker
30+ from sagemaker .session_settings import SessionSettings
3031
3132BUCKET_WITHOUT_WRITING_PERMISSION = "s3://bucket-without-writing-permission"
3233
@@ -390,6 +391,13 @@ def test_repack_model_without_source_dir(tmp, fake_s3):
390391 "/code/inference.py" ,
391392 }
392393
394+ extra_args = {"ServerSideEncryption" : "aws:kms" }
395+ object_mock = fake_s3 .object_mock
396+ _ , _ , kwargs = object_mock .mock_calls [0 ]
397+
398+ assert "ExtraArgs" in kwargs
399+ assert kwargs ["ExtraArgs" ] == extra_args
400+
393401
394402def test_repack_model_with_entry_point_without_path_without_source_dir (tmp , fake_s3 ):
395403
@@ -415,12 +423,20 @@ def test_repack_model_with_entry_point_without_path_without_source_dir(tmp, fake
415423 "s3://fake/location" ,
416424 "s3://destination-bucket/model.tar.gz" ,
417425 fake_s3 .sagemaker_session ,
426+ kms_key = "kms_key" ,
418427 )
419428 finally :
420429 os .chdir (cwd )
421430
422431 assert list_tar_files (fake_s3 .fake_upload_path , tmp ) == {"/code/inference.py" , "/model" }
423432
433+ extra_args = {"ServerSideEncryption" : "aws:kms" , "SSEKMSKeyId" : "kms_key" }
434+ object_mock = fake_s3 .object_mock
435+ _ , _ , kwargs = object_mock .mock_calls [0 ]
436+
437+ assert "ExtraArgs" in kwargs
438+ assert kwargs ["ExtraArgs" ] == extra_args
439+
424440
425441def test_repack_model_from_s3_to_s3 (tmp , fake_s3 ):
426442
@@ -434,6 +450,7 @@ def test_repack_model_from_s3_to_s3(tmp, fake_s3):
434450 )
435451
436452 fake_s3 .tar_and_upload ("model-dir" , "s3://fake/location" )
453+ fake_s3 .sagemaker_session .settings = SessionSettings (encrypt_repacked_artifacts = False )
437454
438455 sagemaker .utils .repack_model (
439456 "inference.py" ,
@@ -450,6 +467,11 @@ def test_repack_model_from_s3_to_s3(tmp, fake_s3):
450467 "/model" ,
451468 }
452469
470+ object_mock = fake_s3 .object_mock
471+ _ , _ , kwargs = object_mock .mock_calls [0 ]
472+ assert "ExtraArgs" in kwargs
473+ assert kwargs ["ExtraArgs" ] is None
474+
453475
454476def test_repack_model_from_file_to_file (tmp ):
455477 create_file_tree (tmp , ["model" , "dependencies/a" , "source-dir/inference.py" ])
@@ -581,6 +603,7 @@ def __init__(self, tmp):
581603 self .sagemaker_session = MagicMock ()
582604 self .location_map = {}
583605 self .current_bucket = None
606+ self .object_mock = MagicMock ()
584607
585608 self .sagemaker_session .boto_session .resource ().Bucket ().download_file .side_effect = (
586609 self .download_file
@@ -606,6 +629,7 @@ def tar_and_upload(self, path, fake_location):
606629
607630 def mock_s3_upload (self ):
608631 dst = os .path .join (self .tmp , "dst" )
632+ object_mock = self .object_mock
609633
610634 class MockS3Object (object ):
611635 def __init__ (self , bucket , key ):
@@ -616,6 +640,7 @@ def upload_file(self, target, **kwargs):
616640 if self .bucket in BUCKET_WITHOUT_WRITING_PERMISSION :
617641 raise exceptions .S3UploadFailedError ()
618642 shutil .copy2 (target , dst )
643+ object_mock .upload_file (target , ** kwargs )
619644
620645 self .sagemaker_session .boto_session .resource ().Object = MockS3Object
621646 return dst
0 commit comments