diff --git a/src/sagemaker_containers/_files.py b/src/sagemaker_containers/_files.py
index 302aa27..abfd31b 100644
--- a/src/sagemaker_containers/_files.py
+++ b/src/sagemaker_containers/_files.py
@@ -132,6 +132,10 @@ def download_and_extract(uri, path):  # type: (str, str) -> None
                 if os.path.exists(path):
                     shutil.rmtree(path)
                 shutil.move(uri, path)
+            elif tarfile.is_tarfile(uri):
+                # is_tarfile() accepts any compression, so let tarfile detect it
+                with tarfile.open(name=uri, mode='r') as t:
+                    t.extractall(path=path)
             else:
                 shutil.copy2(uri, path)
 
diff --git a/test/__init__.py b/test/__init__.py
index 80ff03c..ca696d5 100644
--- a/test/__init__.py
+++ b/test/__init__.py
@@ -150,22 +150,28 @@ def add_file(self, file):  # type: (File) -> UserModule
     def url(self):  # type: () -> str
         return os.path.join('s3://', self.bucket, self.key)
 
-    def upload(self):  # type: () -> UserModule
-        with _files.tmpdir() as tmpdir:
-            tar_name = os.path.join(tmpdir, 'sourcedir.tar.gz')
-            with tarfile.open(tar_name, mode='w:gz') as tar:
-                for _file in self._files:
-                    name = os.path.join(tmpdir, _file.name)
-                    with open(name, 'w+') as f:
+    def create_tar(self, dir_path=None):
+        dir_path = dir_path or os.path.dirname(os.path.realpath(__file__))
+        tar_name = os.path.join(dir_path, 'sourcedir.tar.gz')
+        with tarfile.open(tar_name, mode='w:gz') as tar:
+            for _file in self._files:
+                name = os.path.join(dir_path, _file.name)
+                with open(name, 'w+') as f:
+
+                    if isinstance(_file.data, six.string_types):
+                        data = _file.data
+                    else:
+                        data = '\n'.join(_file.data)
 
-                        if isinstance(_file.data, six.string_types):
-                            data = _file.data
-                        else:
-                            data = '\n'.join(_file.data)
+                    f.write(data)
+                tar.add(name=name, arcname=_file.name)
+                os.remove(name)
 
-                        f.write(data)
-                    tar.add(name=name, arcname=_file.name)
+        return tar_name
 
+    def upload(self):  # type: () -> UserModule
+        with _files.tmpdir() as tmpdir:
+            tar_name = self.create_tar(dir_path=tmpdir)
             self._s3.Object(self.bucket, self.key).upload_file(tar_name)
         return self
 
diff --git a/test/functional/test_download_and_import.py b/test/functional/test_download_and_import.py
index a882347..3affcf8 100644
--- a/test/functional/test_download_and_import.py
+++ b/test/functional/test_download_and_import.py
@@ -13,6 +13,7 @@
 from __future__ import absolute_import
 
 import importlib
+import os
 import shlex
 import subprocess
 import textwrap
@@ -157,3 +158,17 @@ def test_import_module_with_s3_script_with_error(user_module_name):
 
     with pytest.raises(errors.ImportModuleError):
         modules.import_module(user_module.url, user_module_name, cache=False)
+
+
+@pytest.mark.parametrize('user_module',
+                         [test.UserModule(USER_SCRIPT_WITH_REQUIREMENTS).add_file(SETUP_FILE),
+                          test.UserModule(USER_SCRIPT_WITH_REQUIREMENTS)])
+def test_import_module_with_local_tar_via_download_and_extract(user_module, user_module_name):
+    user_module = user_module.add_file(REQUIREMENTS_FILE)
+    tar_name = user_module.create_tar()
+
+    module = modules.import_module(tar_name, name=user_module_name, cache=False)
+
+    assert module.say() == REQUIREMENTS_TXT_ASSERT_STR
+
+    os.remove(tar_name)
diff --git a/test/unit/test_files.py b/test/unit/test_files.py
index 92d1b90..713e26d 100644
--- a/test/unit/test_files.py
+++ b/test/unit/test_files.py
@@ -13,6 +13,7 @@
 import itertools
 import logging
 import os
+import tarfile
 
 from mock import mock_open, patch
 import pytest
@@ -113,7 +114,7 @@ def test_write_failure_file():
 @patch('os.path.isdir', lambda x: True)
 @patch('shutil.rmtree')
 @patch('shutil.move')
-def test_download_and_and_extract_source_dir(move, rmtree, s3_download):
+def test_download_and_extract_source_dir(move, rmtree, s3_download):
     uri = _env.channel_path('code')
     _files.download_and_extract(uri, _env.code_dir)
     s3_download.assert_not_called()
@@ -125,9 +126,24 @@ def test_download_and_and_extract_source_dir(move, rmtree, s3_download):
 @patch('sagemaker_containers._files.s3_download')
 @patch('os.path.isdir', lambda x: False)
 @patch('shutil.copy2')
-def test_download_and_and_extract_file(copy, s3_download):
-    uri = _env.channel_path('code')
+def test_download_and_extract_file(copy, s3_download):
+    uri = __file__
     _files.download_and_extract(uri, _env.code_dir)
 
     s3_download.assert_not_called()
     copy.assert_called_with(uri, _env.code_dir)
+
+
+@patch('sagemaker_containers._files.s3_download')
+@patch('os.path.isdir', lambda x: False)
+@patch('tarfile.TarFile.extractall')
+def test_download_and_extract_tar(extractall, s3_download):
+    t = tarfile.open(name='test.tar.gz', mode='w:gz')
+    t.close()
+    uri = t.name
+    _files.download_and_extract(uri, _env.code_dir)
+
+    s3_download.assert_not_called()
+    extractall.assert_called_with(path=_env.code_dir)
+
+    os.remove(uri)
diff --git a/test/unit/test_intermediate_output.py b/test/unit/test_intermediate_output.py
index cf451ca..e0baefc 100644
--- a/test/unit/test_intermediate_output.py
+++ b/test/unit/test_intermediate_output.py
@@ -37,6 +37,7 @@ def test_wrong_output():
 
 
 @patch('inotify_simple.INotify', MagicMock())
+@patch('boto3.client', MagicMock())
 def test_daemon_process():
     intemediate_sync = _intermediate_output.start_sync(S3_BUCKET, REGION)
     assert intemediate_sync.daemon is True