Merged
146 changes: 92 additions & 54 deletions src/sagemaker/local/image.py
@@ -25,9 +25,9 @@
import subprocess
import sys
import tempfile
from subprocess import Popen
from fcntl import fcntl, F_GETFL, F_SETFL
from six.moves.urllib.parse import urlparse
from time import sleep
from threading import Thread

import yaml

@@ -91,42 +91,7 @@ def train(self, input_data_config, hyperparameters):
os.mkdir(shared_dir)

data_dir = self._create_tmp_folder()
volumes = []

# Set up the channels for the containers. For local data we will
# mount the local directory to the container. For S3 Data we will download the S3 data
# first.
for channel in input_data_config:
if channel['DataSource'] and 'S3DataSource' in channel['DataSource']:
uri = channel['DataSource']['S3DataSource']['S3Uri']
elif channel['DataSource'] and 'FileDataSource' in channel['DataSource']:
uri = channel['DataSource']['FileDataSource']['FileUri']
else:
raise ValueError('Need channel[\'DataSource\'] to have [\'S3DataSource\'] or [\'FileDataSource\']')

parsed_uri = urlparse(uri)
key = parsed_uri.path.lstrip('/')

channel_name = channel['ChannelName']
channel_dir = os.path.join(data_dir, channel_name)
os.mkdir(channel_dir)

if parsed_uri.scheme == 's3':
bucket_name = parsed_uri.netloc
self._download_folder(bucket_name, key, channel_dir)
elif parsed_uri.scheme == 'file':
path = parsed_uri.path
volumes.append(_Volume(path, channel=channel_name))
else:
raise ValueError('Unknown URI scheme {}'.format(parsed_uri.scheme))

# If the training script directory is a local directory, mount it to the container.
training_dir = json.loads(hyperparameters[sagemaker.estimator.DIR_PARAM_NAME])
parsed_uri = urlparse(training_dir)
if parsed_uri.scheme == 'file':
volumes.append(_Volume(parsed_uri.path, '/opt/ml/code'))
# Also mount a directory that all the containers can access.
volumes.append(_Volume(shared_dir, '/opt/ml/shared'))
volumes = self._prepare_training_volumes(data_dir, input_data_config, hyperparameters)

# Create the configuration files for each container that we will create
# Each container will map the additional local volumes (if any).
@@ -139,7 +104,15 @@ def train(self, input_data_config, hyperparameters):
compose_command = self._compose()

_ecr_login_if_needed(self.sagemaker_session.boto_session, self.image)
_execute_and_stream_output(compose_command)
process = subprocess.Popen(compose_command, stdout=subprocess.PIPE, stderr=subprocess.PIPE)

try:
_stream_output(process)
Contributor:
Do you need this? Doesn't something like this work:

>>> subprocess.check_call(["ls", "asdf"], stderr=sys.stdout.fileno())
ls: asdf: No such file or directory
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/usr/local/Cellar/python/2.7.14/Frameworks/Python.framework/Versions/2.7/lib/python2.7/subprocess.py", line 186, in check_call
    raise CalledProcessError(retcode, cmd)
subprocess.CalledProcessError: Command '['ls', 'asdf']' returned non-zero exit status 1

I think the interface is a bit different in Python 3, but we can work through this.

Contributor:
Talked IRL: this won't work in Jupyter, because Jupyter replaces the stderr object.
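
A minimal sketch of the issue discussed here (an illustration under stated assumptions, not part of this diff): in an ipykernel-based notebook, sys.stdout is replaced by ipykernel's OutStream, which does not expose a usable OS-level file descriptor, so redirecting a subprocess through sys.stdout.fileno() is not an option there.

# Hedged illustration: in a notebook, sys.stdout.fileno() typically raises an
# exception (for example io.UnsupportedOperation); in a plain interpreter this
# simply runs echo and prints "hello".
import subprocess
import sys

try:
    subprocess.check_call(['echo', 'hello'], stderr=sys.stdout.fileno())
except Exception as exc:
    print('cannot redirect through sys.stdout.fileno() here: %r' % exc)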

except RuntimeError as e:
# _stream_output() doesn't have the command line. We will handle the exception
# which contains the exit code and append the command line to it.
msg = "Failed to run: %s, %s" % (compose_command, e.message)
raise RuntimeError(msg)
Contributor:
Can you also include e as well?


s3_artifacts = self.retrieve_artifacts(compose_data)

@@ -196,7 +169,7 @@ def serve(self, primary_container):
additional_volumes=volumes)
compose_command = self._compose()
self.container = _HostingContainer(compose_command)
self.container.up()
self.container.start()

def stop_serving(self):
"""Stop the serving container.
@@ -205,6 +178,7 @@ def stop_serving(self):
"""
if self.container:
self.container.down()
self.container.join()
self._cleanup()
# for serving we can delete everything in the container root.
_delete_tree(self.container_root)
@@ -304,6 +278,47 @@ def _download_folder(self, bucket_name, prefix, target):

obj.download_file(file_path)

def _prepare_training_volumes(self, data_dir, input_data_config, hyperparameters):
shared_dir = os.path.join(self.container_root, 'shared')
volumes = []
# Set up the channels for the containers. For local data we will
# mount the local directory to the container. For S3 Data we will download the S3 data
# first.
for channel in input_data_config:
if channel['DataSource'] and 'S3DataSource' in channel['DataSource']:
uri = channel['DataSource']['S3DataSource']['S3Uri']
elif channel['DataSource'] and 'FileDataSource' in channel['DataSource']:
uri = channel['DataSource']['FileDataSource']['FileUri']
else:
raise ValueError('Need channel[\'DataSource\'] to have'
' [\'S3DataSource\'] or [\'FileDataSource\']')

parsed_uri = urlparse(uri)
key = parsed_uri.path.lstrip('/')

channel_name = channel['ChannelName']
channel_dir = os.path.join(data_dir, channel_name)
os.mkdir(channel_dir)

if parsed_uri.scheme == 's3':
bucket_name = parsed_uri.netloc
self._download_folder(bucket_name, key, channel_dir)
elif parsed_uri.scheme == 'file':
path = parsed_uri.path
volumes.append(_Volume(path, channel=channel_name))
else:
raise ValueError('Unknown URI scheme {}'.format(parsed_uri.scheme))

# If the training script directory is a local directory, mount it to the container.
training_dir = json.loads(hyperparameters[sagemaker.estimator.DIR_PARAM_NAME])
parsed_uri = urlparse(training_dir)
if parsed_uri.scheme == 'file':
volumes.append(_Volume(parsed_uri.path, '/opt/ml/code'))
# Also mount a directory that all the containers can access.
volumes.append(_Volume(shared_dir, '/opt/ml/shared'))

return volumes
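
For orientation, a hedged example of the channel configuration this helper consumes; the channel names, bucket, and paths below are illustrative and not taken from this PR:

# Illustrative input only: channel names, bucket, and paths are made up.
input_data_config = [
    {'ChannelName': 'training',
     'DataSource': {'S3DataSource': {'S3Uri': 's3://my-bucket/train'}}},
    {'ChannelName': 'validation',
     'DataSource': {'FileDataSource': {'FileUri': 'file:///tmp/validation'}}},
]
# With this input, the 'training' channel is downloaded under data_dir/training,
# while the 'validation' channel is mounted into the container as a _Volume.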

def _generate_compose_file(self, command, additional_volumes=None, additional_env_vars=None):
"""Writes a config file describing a training/hosting environment.

@@ -452,15 +467,23 @@ def _cleanup(self):
pass


class _HostingContainer(object):
def __init__(self, command, startup_delay=5):
class _HostingContainer(Thread):
def __init__(self, command):
Thread.__init__(self)
Contributor:
Okay if Thread isn't a new-style class.

self.command = command
self.startup_delay = startup_delay
self.process = None

def up(self):
self.process = Popen(self.command)
sleep(self.startup_delay)
def run(self):
Contributor:
This is identical to the code in train, just with a different command. Suggest refactoring into a utility function.

self.process = subprocess.Popen(self.command,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE)
try:
_stream_output(self.process)
except RuntimeError as e:
# _stream_output() doesn't have the command line. We will handle the exception
# which contains the exit code and append the command line to it.
msg = "Failed to run: %s, %s" % (self.command, e.message)
raise RuntimeError(msg)

def down(self):
self.process.terminate()
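
Following the reviewer's suggestion above, one possible shape for a shared helper that both train() and _HostingContainer.run() could call. This is a sketch only; the name _execute_process is hypothetical and not part of this PR:

def _execute_process(command):
    # Start the command with piped output and stream it to stdout, attaching
    # the command line to any failure raised by _stream_output(). str(e) is
    # used so the sketch also works on Python 3.
    process = subprocess.Popen(command,
                               stdout=subprocess.PIPE,
                               stderr=subprocess.PIPE)
    try:
        _stream_output(process)
    except RuntimeError as e:
        raise RuntimeError("Failed to run: %s, %s" % (command, str(e)))
    return process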
@@ -495,26 +518,41 @@ def __init__(self, host_dir, container_dir=None, channel=None):
self.map = '{}:{}'.format(self.host_dir, self.container_dir)


def _execute_and_stream_output(cmd):
"""Execute a command and stream the output to stdout
def _stream_output(process):
"""Stream the output of a process to stdout

This function takes an existing process that will be polled for output. Both stdout and
stderr will be polled and both will be sent to sys.stdout.

Args:
cmd(str or List): either a string or a List (in Popen Format) with the command to execute.
process(subprocess.Popen): a process that has been started with
stdout=PIPE and stderr=PIPE

Returns (int): command exit code
Returns (int): process exit code
"""
if isinstance(cmd, str):
cmd = shlex.split(cmd)
process = subprocess.Popen(cmd, stdout=subprocess.PIPE)
exit_code = None

# Get the current flags for the stderr file descriptor
# And add the NONBLOCK flag to allow us to read even if there is no data.
# Since usually stderr will be empty unless there is an error.
flags = fcntl(process.stderr, F_GETFL) # get current process.stderr flags
fcntl(process.stderr, F_SETFL, flags | os.O_NONBLOCK)

while exit_code is None:
stdout = process.stdout.readline().decode("utf-8")
sys.stdout.write(stdout)
try:
stderr = process.stderr.readline().decode("utf-8")
sys.stdout.write(stderr)
except IOError:
# If there is nothing to read on stderr we will get an IOError
# this is fine.
pass

exit_code = process.poll()

if exit_code != 0:
raise Exception("Failed to run %s, exit code: %s" % (",".join(cmd), exit_code))
raise RuntimeError("Process exited with code: %s" % exit_code)

return exit_code
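
A short usage sketch (it mirrors how train() and _HostingContainer.run() drive this function; the command is illustrative):

# Both output streams must be PIPEs for _stream_output() to poll them.
proc = subprocess.Popen(['echo', 'hello'],
                        stdout=subprocess.PIPE,
                        stderr=subprocess.PIPE)
exit_code = _stream_output(proc)  # echoes "hello" to sys.stdout and returns 0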

35 changes: 21 additions & 14 deletions tests/unit/test_image.py
@@ -15,6 +15,7 @@
import base64
import json
import os
import subprocess

import pytest
import yaml
@@ -180,13 +181,16 @@ def test_retrieve_artifacts(LocalSession, tmpdir):
def test_stream_output():

# it should raise an exception if the command fails
with pytest.raises(Exception):
sagemaker.local.image._execute_and_stream_output(['ls', '/some/unknown/path'])

exit_code = sagemaker.local.image._execute_and_stream_output(['echo', 'hello'])
assert exit_code == 0

exit_code = sagemaker.local.image._execute_and_stream_output('echo hello!!!')
with pytest.raises(RuntimeError):
p = subprocess.Popen(['ls', '/some/unknown/path'],
stdout=subprocess.PIPE,
stderr=subprocess.PIPE)
sagemaker.local.image._stream_output(p)

p = subprocess.Popen(['echo', 'hello'],
stdout=subprocess.PIPE,
stderr=subprocess.PIPE)
exit_code = sagemaker.local.image._stream_output(p)
assert exit_code == 0


@@ -205,10 +209,12 @@ def test_check_output():


@patch('sagemaker.local.local_session.LocalSession')
@patch('sagemaker.local.image._execute_and_stream_output')
@patch('sagemaker.local.image._stream_output')
@patch('subprocess.Popen')
@patch('sagemaker.local.image._SageMakerContainer._cleanup')
@patch('sagemaker.local.image._SageMakerContainer._download_folder')
def test_train(_download_folder, _cleanup, _execute_and_stream_output, LocalSession, tmpdir, sagemaker_session):
def test_train(_download_folder, _cleanup, popen, _stream_output, LocalSession,
tmpdir, sagemaker_session):

directories = [str(tmpdir.mkdir('container-root')), str(tmpdir.mkdir('data'))]
with patch('sagemaker.local.image._SageMakerContainer._create_tmp_folder',
@@ -225,7 +231,7 @@ def test_train(_download_folder, _cleanup, _execute_and_stream_output, LocalSess

docker_compose_file = os.path.join(sagemaker_container.container_root, 'docker-compose.yaml')

call_args = _execute_and_stream_output.call_args[0][0]
call_args = popen.call_args[0][0]
assert call_args is not None

expected = ['docker-compose', '-f', docker_compose_file, 'up', '--build', '--abort-on-container-exit']
@@ -241,10 +247,11 @@ def test_train(_download_folder, _cleanup, _execute_and_stream_output, LocalSess


@patch('sagemaker.local.local_session.LocalSession')
@patch('sagemaker.local.image._execute_and_stream_output')
@patch('sagemaker.local.image._stream_output')
@patch('subprocess.Popen')
@patch('sagemaker.local.image._SageMakerContainer._cleanup')
@patch('sagemaker.local.image._SageMakerContainer._download_folder')
def test_train_local_code(_download_folder, _cleanup, _execute_and_stream_output,
def test_train_local_code(_download_folder, _cleanup, popen, _stream_output,
_local_session, tmpdir, sagemaker_session):
directories = [str(tmpdir.mkdir('container-root')), str(tmpdir.mkdir('data'))]
with patch('sagemaker.local.image._SageMakerContainer._create_tmp_folder',
@@ -271,7 +278,7 @@ def test_train_local_code(_download_folder, _cleanup, _execute_and_stream_output
assert '%s:/opt/ml/shared' % shared_folder_path in volumes


@patch('sagemaker.local.image._HostingContainer.up')
@patch('sagemaker.local.image._HostingContainer.run')
@patch('shutil.copy')
@patch('shutil.copytree')
def test_serve(up, copy, copytree, tmpdir, sagemaker_session):
@@ -299,7 +306,7 @@ def test_serve(up, copy, copytree, tmpdir, sagemaker_session):
assert config['services'][h]['command'] == 'serve'


@patch('sagemaker.local.image._HostingContainer.up')
@patch('sagemaker.local.image._HostingContainer.run')
@patch('shutil.copy')
@patch('shutil.copytree')
def test_serve_local_code(up, copy, copytree, tmpdir, sagemaker_session):