@@ -268,23 +268,57 @@ def test_download_file():
268268
269269@patch ('tarfile.open' )
270270def test_create_tar_file_with_provided_path (open ):
271- open .return_value = open
272- open .__enter__ = Mock ()
273- open .__exit__ = Mock (return_value = None )
271+ files = mock_tarfile (open )
272+
274273 file_list = ['/tmp/a' , '/tmp/b' ]
274+
275275 path = sagemaker .utils .create_tar_file (file_list , target = '/my/custom/path.tar.gz' )
276276 assert path == '/my/custom/path.tar.gz'
277+ assert files == [['/tmp/a' , 'a' ], ['/tmp/b' , 'b' ]]
277278
278279
279280@patch ('tarfile.open' )
280- @patch ('tempfile.mkstemp' , Mock (return_value = (None , '/auto/generated/path' )))
281- def test_create_tar_file_with_auto_generated_path (open ):
281+ def test_create_tar_file_with_directories (open ):
282+ files = mock_tarfile (open )
283+
284+ path = sagemaker .utils .create_tar_file (dir_files = ['/tmp/a' , '/tmp/b' ],
285+ target = '/my/custom/path.tar.gz' )
286+ assert path == '/my/custom/path.tar.gz'
287+ assert files == [['/tmp/a' , '/' ], ['/tmp/b' , '/' ]]
288+
289+
290+ @patch ('tarfile.open' )
291+ def test_create_tar_file_with_files_and_directories (open ):
292+ files = mock_tarfile (open )
293+
294+ path = sagemaker .utils .create_tar_file (dir_files = ['/tmp/a' , '/tmp/b' ],
295+ source_files = ['/tmp/c' , '/tmp/d' ],
296+ target = '/my/custom/path.tar.gz' )
297+ assert path == '/my/custom/path.tar.gz'
298+ assert files == [['/tmp/c' , 'c' ], ['/tmp/d' , 'd' ], ['/tmp/a' , '/' ], ['/tmp/b' , '/' ]]
299+
300+
301+ def mock_tarfile (open ):
282302 open .return_value = open
303+ files = []
304+
305+ def add_files (filename , arcname ):
306+ files .append ([filename , arcname ])
307+
283308 open .__enter__ = Mock ()
309+ open .__enter__ ().add = add_files
284310 open .__exit__ = Mock (return_value = None )
285- file_list = ['/tmp/a' , '/tmp/b' ]
286- path = sagemaker .utils .create_tar_file (file_list )
311+ return files
312+
313+
314+ @patch ('tarfile.open' )
315+ @patch ('tempfile.mkstemp' , Mock (return_value = (None , '/auto/generated/path' )))
316+ def test_create_tar_file_with_auto_generated_path (open ):
317+ files = mock_tarfile (open )
318+
319+ path = sagemaker .utils .create_tar_file (['/tmp/a' , '/tmp/b' ])
287320 assert path == '/auto/generated/path'
321+ assert files == [['/tmp/a' , 'a' ], ['/tmp/b' , 'b' ]]
288322
289323
290324def write_file (path , content ):
0 commit comments