Skip to content
This repository was archived by the owner on Aug 26, 2020. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions src/sagemaker_containers/_files.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,9 @@ def download_and_extract(uri, path): # type: (str, str) -> None
if os.path.exists(path):
shutil.rmtree(path)
shutil.move(uri, path)
elif tarfile.is_tarfile(uri):
with tarfile.open(name=uri, mode='r:gz') as t:
t.extractall(path=path)
else:
shutil.copy2(uri, path)

Expand Down
32 changes: 19 additions & 13 deletions test/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,22 +150,28 @@ def add_file(self, file): # type: (File) -> UserModule
def url(self): # type: () -> str
return os.path.join('s3://', self.bucket, self.key)

def upload(self): # type: () -> UserModule
with _files.tmpdir() as tmpdir:
tar_name = os.path.join(tmpdir, 'sourcedir.tar.gz')
with tarfile.open(tar_name, mode='w:gz') as tar:
for _file in self._files:
name = os.path.join(tmpdir, _file.name)
with open(name, 'w+') as f:
def create_tar(self, dir_path=None):
dir_path = dir_path or os.path.dirname(os.path.realpath(__file__))
tar_name = os.path.join(dir_path, 'sourcedir.tar.gz')
with tarfile.open(tar_name, mode='w:gz') as tar:
for _file in self._files:
name = os.path.join(dir_path, _file.name)
with open(name, 'w+') as f:

if isinstance(_file.data, six.string_types):
data = _file.data
else:
data = '\n'.join(_file.data)

if isinstance(_file.data, six.string_types):
data = _file.data
else:
data = '\n'.join(_file.data)
f.write(data)
tar.add(name=name, arcname=_file.name)
os.remove(name)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You dont need to delete the file here because everything is being written inside the tmp dir anyways.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We need this because that tmpdir gets deleted when the function returns if you use a tmpdir for the tar file location. We use the contents of the tar file for the test functional/test_download_and_impor.py::test_import_module_with_local_tar_via_download_and_extract so in that use case, we are using the current directory (not tmpdir)


f.write(data)
tar.add(name=name, arcname=_file.name)
return tar_name

def upload(self): # type: () -> UserModule
with _files.tmpdir() as tmpdir:
tar_name = self.create_tar(dir_path=tmpdir)
self._s3.Object(self.bucket, self.key).upload_file(tar_name)
return self

Expand Down
15 changes: 15 additions & 0 deletions test/functional/test_download_and_import.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from __future__ import absolute_import

import importlib
import os
import shlex
import subprocess
import textwrap
Expand Down Expand Up @@ -157,3 +158,17 @@ def test_import_module_with_s3_script_with_error(user_module_name):

with pytest.raises(errors.ImportModuleError):
modules.import_module(user_module.url, user_module_name, cache=False)


@pytest.mark.parametrize('user_module',
[test.UserModule(USER_SCRIPT_WITH_REQUIREMENTS).add_file(SETUP_FILE),
test.UserModule(USER_SCRIPT_WITH_REQUIREMENTS)])
def test_import_module_with_local_tar_via_download_and_extract(user_module, user_module_name):
user_module = user_module.add_file(REQUIREMENTS_FILE)
tar_name = user_module.create_tar()

module = modules.import_module(tar_name, name=user_module_name, cache=False)

assert module.say() == REQUIREMENTS_TXT_ASSERT_STR

os.remove(tar_name)
22 changes: 19 additions & 3 deletions test/unit/test_files.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import itertools
import logging
import os
import tarfile

from mock import mock_open, patch
import pytest
Expand Down Expand Up @@ -113,7 +114,7 @@ def test_write_failure_file():
@patch('os.path.isdir', lambda x: True)
@patch('shutil.rmtree')
@patch('shutil.move')
def test_download_and_and_extract_source_dir(move, rmtree, s3_download):
def test_download_and_extract_source_dir(move, rmtree, s3_download):
uri = _env.channel_path('code')
_files.download_and_extract(uri, _env.code_dir)
s3_download.assert_not_called()
Expand All @@ -125,9 +126,24 @@ def test_download_and_and_extract_source_dir(move, rmtree, s3_download):
@patch('sagemaker_containers._files.s3_download')
@patch('os.path.isdir', lambda x: False)
@patch('shutil.copy2')
def test_download_and_and_extract_file(copy, s3_download):
uri = _env.channel_path('code')
def test_download_and_extract_file(copy, s3_download):
uri = __file__
_files.download_and_extract(uri, _env.code_dir)

s3_download.assert_not_called()
copy.assert_called_with(uri, _env.code_dir)


@patch('sagemaker_containers._files.s3_download')
@patch('os.path.isdir', lambda x: False)
@patch('tarfile.TarFile.extractall')
def test_download_and_extract_tar(extractall, s3_download):
t = tarfile.open(name='test.tar.gz', mode='w:gz')
t.close()
uri = t.name
_files.download_and_extract(uri, _env.code_dir)

s3_download.assert_not_called()
extractall.assert_called_with(path=_env.code_dir)

os.remove(uri)
1 change: 1 addition & 0 deletions test/unit/test_intermediate_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ def test_wrong_output():


@patch('inotify_simple.INotify', MagicMock())
@patch('boto3.client', MagicMock())
def test_daemon_process():
intemediate_sync = _intermediate_output.start_sync(S3_BUCKET, REGION)
assert intemediate_sync.daemon is True
Expand Down