From 1ad6f7ee2e6fa39452649acc6065b11192ca86a4 Mon Sep 17 00:00:00 2001 From: Wilton Wu Date: Tue, 14 May 2019 18:23:38 -0700 Subject: [PATCH 1/9] change: download_and_extract local tar file --- src/sagemaker_containers/_files.py | 3 +++ test/unit/test_files.py | 19 +++++++++++++++++-- 2 files changed, 20 insertions(+), 2 deletions(-) diff --git a/src/sagemaker_containers/_files.py b/src/sagemaker_containers/_files.py index 302aa27..abfd31b 100644 --- a/src/sagemaker_containers/_files.py +++ b/src/sagemaker_containers/_files.py @@ -132,6 +132,9 @@ def download_and_extract(uri, path): # type: (str, str) -> None if os.path.exists(path): shutil.rmtree(path) shutil.move(uri, path) + elif tarfile.is_tarfile(uri): + with tarfile.open(name=uri, mode='r:gz') as t: + t.extractall(path=path) else: shutil.copy2(uri, path) diff --git a/test/unit/test_files.py b/test/unit/test_files.py index 92d1b90..66ec410 100644 --- a/test/unit/test_files.py +++ b/test/unit/test_files.py @@ -13,6 +13,7 @@ import itertools import logging import os +import tarfile from mock import mock_open, patch import pytest @@ -113,7 +114,7 @@ def test_write_failure_file(): @patch('os.path.isdir', lambda x: True) @patch('shutil.rmtree') @patch('shutil.move') -def test_download_and_and_extract_source_dir(move, rmtree, s3_download): +def test_download_and_extract_source_dir(move, rmtree, s3_download): uri = _env.channel_path('code') _files.download_and_extract(uri, _env.code_dir) s3_download.assert_not_called() @@ -125,9 +126,23 @@ def test_download_and_and_extract_source_dir(move, rmtree, s3_download): @patch('sagemaker_containers._files.s3_download') @patch('os.path.isdir', lambda x: False) @patch('shutil.copy2') -def test_download_and_and_extract_file(copy, s3_download): +def test_download_and_extract_file(copy, s3_download): uri = _env.channel_path('code') _files.download_and_extract(uri, _env.code_dir) s3_download.assert_not_called() copy.assert_called_with(uri, _env.code_dir) + +@patch('sagemaker_containers._files.s3_download') +@patch('os.path.isdir', lambda x: False) +@patch('tarfile.TarFile.extractall') +def test_download_and_extract_tar(extractall, s3_download): + t = tarfile.open(name='test.tar.gz', mode='w:gz') + t.close() + uri = t.name + _files.download_and_extract(uri, _env.code_dir) + + s3_download.assert_not_called() + extractall.assert_called() + + os.remove(uri) From 06a7e9a2fbc680b056bf463833aaf50ee18bdfa1 Mon Sep 17 00:00:00 2001 From: Wilton Wu Date: Tue, 14 May 2019 18:37:34 -0700 Subject: [PATCH 2/9] flake8 fix --- test/unit/test_files.py | 1 + 1 file changed, 1 insertion(+) diff --git a/test/unit/test_files.py b/test/unit/test_files.py index 66ec410..89b6bb5 100644 --- a/test/unit/test_files.py +++ b/test/unit/test_files.py @@ -133,6 +133,7 @@ def test_download_and_extract_file(copy, s3_download): s3_download.assert_not_called() copy.assert_called_with(uri, _env.code_dir) + @patch('sagemaker_containers._files.s3_download') @patch('os.path.isdir', lambda x: False) @patch('tarfile.TarFile.extractall') From 1f6eab2da5b97cec9bd385090e3eee5df75cb284 Mon Sep 17 00:00:00 2001 From: Wilton Wu Date: Wed, 15 May 2019 11:00:31 -0700 Subject: [PATCH 3/9] change: check that file exists in download_and_extract --- src/sagemaker_containers/_files.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/sagemaker_containers/_files.py b/src/sagemaker_containers/_files.py index abfd31b..a6ffa55 100644 --- a/src/sagemaker_containers/_files.py +++ b/src/sagemaker_containers/_files.py @@ -132,7 +132,7 @@ def download_and_extract(uri, path): # type: (str, str) -> None if os.path.exists(path): shutil.rmtree(path) shutil.move(uri, path) - elif tarfile.is_tarfile(uri): + elif os.path.exists(uri) and tarfile.is_tarfile(uri): with tarfile.open(name=uri, mode='r:gz') as t: t.extractall(path=path) else: From 88eef65e1eabe55dc93129487970c6f6880f5ed0 Mon Sep 17 00:00:00 2001 From: Wilton Wu Date: Wed, 15 May 2019 16:23:22 -0700 Subject: [PATCH 4/9] change: add functional test for tar module, fix file test, remove file exists in download_and_extract --- src/sagemaker_containers/_files.py | 2 +- test/__init__.py | 30 ++++++++++++--------- test/functional/test_download_and_import.py | 15 +++++++++++ test/unit/test_files.py | 2 +- 4 files changed, 34 insertions(+), 15 deletions(-) diff --git a/src/sagemaker_containers/_files.py b/src/sagemaker_containers/_files.py index a6ffa55..abfd31b 100644 --- a/src/sagemaker_containers/_files.py +++ b/src/sagemaker_containers/_files.py @@ -132,7 +132,7 @@ def download_and_extract(uri, path): # type: (str, str) -> None if os.path.exists(path): shutil.rmtree(path) shutil.move(uri, path) - elif os.path.exists(uri) and tarfile.is_tarfile(uri): + elif tarfile.is_tarfile(uri): with tarfile.open(name=uri, mode='r:gz') as t: t.extractall(path=path) else: diff --git a/test/__init__.py b/test/__init__.py index 80ff03c..c02e1db 100644 --- a/test/__init__.py +++ b/test/__init__.py @@ -150,22 +150,26 @@ def add_file(self, file): # type: (File) -> UserModule def url(self): # type: () -> str return os.path.join('s3://', self.bucket, self.key) - def upload(self): # type: () -> UserModule - with _files.tmpdir() as tmpdir: - tar_name = os.path.join(tmpdir, 'sourcedir.tar.gz') - with tarfile.open(tar_name, mode='w:gz') as tar: - for _file in self._files: - name = os.path.join(tmpdir, _file.name) - with open(name, 'w+') as f: + def create_tar(self, dir=os.getcwd()): + tar_name = os.path.join(dir, 'sourcedir.tar.gz') + with tarfile.open(tar_name, mode='w:gz') as tar: + for _file in self._files: + name = os.path.join(dir, _file.name) + with open(name, 'w+') as f: + + if isinstance(_file.data, six.string_types): + data = _file.data + else: + data = '\n'.join(_file.data) - if isinstance(_file.data, six.string_types): - data = _file.data - else: - data = '\n'.join(_file.data) + f.write(data) + tar.add(name=name, arcname=_file.name) - f.write(data) - tar.add(name=name, arcname=_file.name) + return tar_name + def upload(self): # type: () -> UserModule + with _files.tmpdir() as tmpdir: + tar_name = self.create_tar(dir=tmpdir) self._s3.Object(self.bucket, self.key).upload_file(tar_name) return self diff --git a/test/functional/test_download_and_import.py b/test/functional/test_download_and_import.py index a882347..adfabc0 100644 --- a/test/functional/test_download_and_import.py +++ b/test/functional/test_download_and_import.py @@ -19,6 +19,7 @@ import pytest import six +import os from sagemaker_containers.beta.framework import errors, modules import test @@ -157,3 +158,17 @@ def test_import_module_with_s3_script_with_error(user_module_name): with pytest.raises(errors.ImportModuleError): modules.import_module(user_module.url, user_module_name, cache=False) + + +@pytest.mark.parametrize('user_module', + [test.UserModule(USER_SCRIPT_WITH_REQUIREMENTS).add_file(SETUP_FILE), + test.UserModule(USER_SCRIPT_WITH_REQUIREMENTS)]) +def test_import_module_with_local_tar_via_download_and_extract(user_module, user_module_name): + user_module = user_module.add_file(REQUIREMENTS_FILE) + tar_name = user_module.create_tar() + + module = modules.import_module(tar_name, name=user_module_name, cache=False) + + assert module.say() == REQUIREMENTS_TXT_ASSERT_STR + + os.remove(tar_name) diff --git a/test/unit/test_files.py b/test/unit/test_files.py index 89b6bb5..6ef8001 100644 --- a/test/unit/test_files.py +++ b/test/unit/test_files.py @@ -127,7 +127,7 @@ def test_download_and_extract_source_dir(move, rmtree, s3_download): @patch('os.path.isdir', lambda x: False) @patch('shutil.copy2') def test_download_and_extract_file(copy, s3_download): - uri = _env.channel_path('code') + uri = __file__ _files.download_and_extract(uri, _env.code_dir) s3_download.assert_not_called() From 827a843695e559d6482099e03ffc2b5a61f72bee Mon Sep 17 00:00:00 2001 From: Wilton Wu Date: Wed, 15 May 2019 16:27:27 -0700 Subject: [PATCH 5/9] flake8 --- test/functional/test_download_and_import.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/functional/test_download_and_import.py b/test/functional/test_download_and_import.py index adfabc0..c2e1dc9 100644 --- a/test/functional/test_download_and_import.py +++ b/test/functional/test_download_and_import.py @@ -16,10 +16,10 @@ import shlex import subprocess import textwrap +import os import pytest import six -import os from sagemaker_containers.beta.framework import errors, modules import test From 4a061e4a77a43de04e9212d8223acdedd460fe1f Mon Sep 17 00:00:00 2001 From: Wilton Wu Date: Wed, 15 May 2019 16:32:36 -0700 Subject: [PATCH 6/9] flake8 --- test/functional/test_download_and_import.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/functional/test_download_and_import.py b/test/functional/test_download_and_import.py index c2e1dc9..3affcf8 100644 --- a/test/functional/test_download_and_import.py +++ b/test/functional/test_download_and_import.py @@ -13,10 +13,10 @@ from __future__ import absolute_import import importlib +import os import shlex import subprocess import textwrap -import os import pytest import six From ced6a4e53c9e502eab1c71cf43f5cc45dbc6dd26 Mon Sep 17 00:00:00 2001 From: Wilton Wu Date: Wed, 15 May 2019 16:44:53 -0700 Subject: [PATCH 7/9] change: ecs credential error fix --- test/unit/test_intermediate_output.py | 1 + 1 file changed, 1 insertion(+) diff --git a/test/unit/test_intermediate_output.py b/test/unit/test_intermediate_output.py index cf451ca..e0baefc 100644 --- a/test/unit/test_intermediate_output.py +++ b/test/unit/test_intermediate_output.py @@ -37,6 +37,7 @@ def test_wrong_output(): @patch('inotify_simple.INotify', MagicMock()) +@patch('boto3.client', MagicMock()) def test_daemon_process(): intemediate_sync = _intermediate_output.start_sync(S3_BUCKET, REGION) assert intemediate_sync.daemon is True From 625084e0093eacbf466dd0eecae91e7f02b3e94c Mon Sep 17 00:00:00 2001 From: Wilton Wu Date: Wed, 15 May 2019 22:11:04 -0700 Subject: [PATCH 8/9] change: remove transient failure in test_download_and_import --- test/__init__.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/test/__init__.py b/test/__init__.py index c02e1db..8b4d7d8 100644 --- a/test/__init__.py +++ b/test/__init__.py @@ -150,7 +150,7 @@ def add_file(self, file): # type: (File) -> UserModule def url(self): # type: () -> str return os.path.join('s3://', self.bucket, self.key) - def create_tar(self, dir=os.getcwd()): + def create_tar(self, dir=os.path.dirname(os.path.realpath(__file__))): tar_name = os.path.join(dir, 'sourcedir.tar.gz') with tarfile.open(tar_name, mode='w:gz') as tar: for _file in self._files: @@ -164,6 +164,7 @@ def create_tar(self, dir=os.getcwd()): f.write(data) tar.add(name=name, arcname=_file.name) + os.remove(name) return tar_name From ac90ff06635e050a6f1aeeb0ea3263f94ebcf39c Mon Sep 17 00:00:00 2001 From: Wilton Wu Date: Thu, 16 May 2019 15:33:54 -0700 Subject: [PATCH 9/9] change: fix dir in create_tar --- test/__init__.py | 9 +++++---- test/unit/test_files.py | 2 +- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/test/__init__.py b/test/__init__.py index 8b4d7d8..ca696d5 100644 --- a/test/__init__.py +++ b/test/__init__.py @@ -150,11 +150,12 @@ def add_file(self, file): # type: (File) -> UserModule def url(self): # type: () -> str return os.path.join('s3://', self.bucket, self.key) - def create_tar(self, dir=os.path.dirname(os.path.realpath(__file__))): - tar_name = os.path.join(dir, 'sourcedir.tar.gz') + def create_tar(self, dir_path=None): + dir_path = dir_path or os.path.dirname(os.path.realpath(__file__)) + tar_name = os.path.join(dir_path, 'sourcedir.tar.gz') with tarfile.open(tar_name, mode='w:gz') as tar: for _file in self._files: - name = os.path.join(dir, _file.name) + name = os.path.join(dir_path, _file.name) with open(name, 'w+') as f: if isinstance(_file.data, six.string_types): @@ -170,7 +171,7 @@ def create_tar(self, dir=os.path.dirname(os.path.realpath(__file__))): def upload(self): # type: () -> UserModule with _files.tmpdir() as tmpdir: - tar_name = self.create_tar(dir=tmpdir) + tar_name = self.create_tar(dir_path=tmpdir) self._s3.Object(self.bucket, self.key).upload_file(tar_name) return self diff --git a/test/unit/test_files.py b/test/unit/test_files.py index 6ef8001..713e26d 100644 --- a/test/unit/test_files.py +++ b/test/unit/test_files.py @@ -144,6 +144,6 @@ def test_download_and_extract_tar(extractall, s3_download): _files.download_and_extract(uri, _env.code_dir) s3_download.assert_not_called() - extractall.assert_called() + extractall.assert_called_with(path=_env.code_dir) os.remove(uri)