Skip to content

Commit e00f823

Browse files
authored
reuse sagemaker-inference's requirements.txt installation logic (#150)
Closes #149
1 parent 4d3ae4b commit e00f823

File tree

2 files changed

+10
-24
lines changed

2 files changed

+10
-24
lines changed

src/sagemaker_pytorch_serving_container/torchserve.py

Lines changed: 3 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
import os
1717
import signal
1818
import subprocess
19-
import sys
2019

2120
import pkg_resources
2221
import psutil
@@ -25,8 +24,7 @@
2524

2625
import sagemaker_pytorch_serving_container
2726
from sagemaker_pytorch_serving_container import ts_environment
28-
from sagemaker_inference import environment, utils
29-
from sagemaker_inference.environment import code_dir
27+
from sagemaker_inference import environment, utils, model_server
3028

3129
logger = logging.getLogger()
3230

@@ -47,7 +45,6 @@
4745
MODEL_STORE = "/" if ENABLE_MULTI_MODEL else os.path.join(os.getcwd(), ".sagemaker", "ts", "models")
4846

4947
PYTHON_PATH_ENV = "PYTHONPATH"
50-
REQUIREMENTS_PATH = os.path.join(code_dir, "requirements.txt")
5148
TS_NAMESPACE = "org.pytorch.serve.ModelServer"
5249

5350

@@ -78,8 +75,8 @@ def start_torchserve(handler_service=DEFAULT_HANDLER_SERVICE):
7875

7976
_create_torchserve_config_file(handler_service)
8077

81-
if os.path.exists(REQUIREMENTS_PATH):
82-
_install_requirements()
78+
if os.path.exists(model_server.REQUIREMENTS_PATH):
79+
model_server._install_requirements()
8380

8481
ts_torchserve_cmd = [
8582
"torchserve",
@@ -181,17 +178,6 @@ def _terminate(signo, frame): # pylint: disable=unused-argument
181178
signal.signal(signal.SIGTERM, _terminate)
182179

183180

184-
def _install_requirements():
185-
logger.info("installing packages from requirements.txt...")
186-
pip_install_cmd = [sys.executable, "-m", "pip", "install", "-r", REQUIREMENTS_PATH]
187-
188-
try:
189-
subprocess.check_call(pip_install_cmd)
190-
except subprocess.CalledProcessError:
191-
logger.exception("failed to install required packages, exiting")
192-
raise ValueError("failed to install required packages")
193-
194-
195181
# retry for 10 seconds
196182
@retry(stop_max_delay=10 * 1000)
197183
def _retrieve_ts_server_process():

test/unit/test_model_server.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,9 @@
1919
from mock import Mock, patch
2020
import pytest
2121

22-
from sagemaker_inference import environment
22+
from sagemaker_inference import environment, model_server
2323
from sagemaker_pytorch_serving_container import torchserve
24-
from sagemaker_pytorch_serving_container.torchserve import TS_NAMESPACE, REQUIREMENTS_PATH
24+
from sagemaker_pytorch_serving_container.torchserve import TS_NAMESPACE
2525

2626
PYTHON_PATH = "python_path"
2727
DEFAULT_CONFIGURATION = "default_configuration"
@@ -31,7 +31,7 @@
3131
@patch("subprocess.Popen")
3232
@patch("sagemaker_pytorch_serving_container.torchserve._retrieve_ts_server_process")
3333
@patch("sagemaker_pytorch_serving_container.torchserve._add_sigterm_handler")
34-
@patch("sagemaker_pytorch_serving_container.torchserve._install_requirements")
34+
@patch("sagemaker_inference.model_server._install_requirements")
3535
@patch("os.path.exists", return_value=True)
3636
@patch("sagemaker_pytorch_serving_container.torchserve._create_torchserve_config_file")
3737
@patch("sagemaker_pytorch_serving_container.torchserve._set_python_path")
@@ -72,7 +72,7 @@ def test_start_torchserve_default_service_handler(
7272
@patch("subprocess.Popen")
7373
@patch("sagemaker_pytorch_serving_container.torchserve._retrieve_ts_server_process")
7474
@patch("sagemaker_pytorch_serving_container.torchserve._add_sigterm_handler")
75-
@patch("sagemaker_pytorch_serving_container.torchserve._install_requirements")
75+
@patch("sagemaker_inference.model_server._install_requirements")
7676
@patch("os.path.exists", return_value=True)
7777
@patch("sagemaker_pytorch_serving_container.torchserve._create_torchserve_config_file")
7878
@patch("sagemaker_pytorch_serving_container.torchserve._set_python_path")
@@ -92,7 +92,7 @@ def test_start_torchserve_default_service_handler_multi_model(
9292

9393
set_python_path.assert_called_once_with()
9494
create_config.assert_called_once_with(torchserve.DEFAULT_HANDLER_SERVICE)
95-
exists.assert_called_once_with(REQUIREMENTS_PATH)
95+
exists.assert_called_once_with(model_server.REQUIREMENTS_PATH)
9696
install_requirements.assert_called_once_with()
9797

9898
ts_model_server_cmd = [
@@ -210,15 +210,15 @@ def test_add_sigterm_handler(signal_call):
210210

211211
@patch("subprocess.check_call")
212212
def test_install_requirements(check_call):
213-
torchserve._install_requirements()
213+
model_server._install_requirements()
214214
for i in ['pip', 'install', '-r', '/opt/ml/model/code/requirements.txt']:
215215
assert i in check_call.call_args.args[0]
216216

217217

218218
@patch("subprocess.check_call", side_effect=subprocess.CalledProcessError(0, "cmd"))
219219
def test_install_requirements_installation_failed(check_call):
220220
with pytest.raises(ValueError) as e:
221-
torchserve._install_requirements()
221+
model_server._install_requirements()
222222
assert "failed to install required packages" in str(e.value)
223223

224224

0 commit comments

Comments
 (0)