Skip to content

Commit 211f4e5

Browse files
authored
breaking: preserve script path when S3 source_dir is provided (#941)
1 parent db21a38 commit 211f4e5

File tree

13 files changed

+69
-29
lines changed

13 files changed

+69
-29
lines changed

src/sagemaker/chainer/estimator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -214,7 +214,7 @@ def create_model(
214214
return ChainerModel(
215215
self.model_data,
216216
role or self.role,
217-
entry_point or self.entry_point,
217+
entry_point or self._model_entry_point(),
218218
source_dir=(source_dir or self._model_source_dir()),
219219
enable_cloudwatch_metrics=self.enable_cloudwatch_metrics,
220220
container_log_level=self.container_log_level,

src/sagemaker/estimator.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1734,17 +1734,28 @@ def _stage_user_code_in_s3(self):
17341734
)
17351735

17361736
def _model_source_dir(self):
1737-
"""Get the appropriate value to pass as source_dir to model constructor
1738-
on deploying
1737+
"""Get the appropriate value to pass as ``source_dir`` to a model constructor.
17391738
17401739
Returns:
1741-
str: Either a local or an S3 path pointing to the source_dir to be
1742-
used for code by the model to be deployed
1740+
str: Either a local or an S3 path pointing to the ``source_dir`` to be
1741+
used for code by the model to be deployed
17431742
"""
17441743
return (
17451744
self.source_dir if self.sagemaker_session.local_mode else self.uploaded_code.s3_prefix
17461745
)
17471746

1747+
def _model_entry_point(self):
1748+
"""Get the appropriate value to pass as ``entry_point`` to a model constructor.
1749+
1750+
Returns:
1751+
str: The path to the entry point script. This can be either an absolute path or
1752+
a path relative to ``self._model_source_dir()``.
1753+
"""
1754+
if self.sagemaker_session.local_mode or (self._model_source_dir() is None):
1755+
return self.entry_point
1756+
1757+
return self.uploaded_code.script_name
1758+
17481759
def hyperparameters(self):
17491760
"""Return the hyperparameters as a dictionary to use for training.
17501761

src/sagemaker/fw_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -447,7 +447,7 @@ def tar_and_upload_dir(
447447
script name.
448448
"""
449449
if directory and directory.lower().startswith("s3://"):
450-
return UploadedCode(s3_prefix=directory, script_name=os.path.basename(script))
450+
return UploadedCode(s3_prefix=directory, script_name=script)
451451

452452
script_name = script if directory else os.path.basename(script)
453453
dependencies = dependencies or []

src/sagemaker/mxnet/estimator.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -218,10 +218,10 @@ def create_model(
218218

219219
kwargs["name"] = self._get_or_create_name(kwargs.get("name"))
220220

221-
return MXNetModel(
221+
model = MXNetModel(
222222
self.model_data,
223223
role or self.role,
224-
entry_point or self.entry_point,
224+
entry_point,
225225
framework_version=self.framework_version,
226226
py_version=self.py_version,
227227
source_dir=(source_dir or self._model_source_dir()),
@@ -235,6 +235,13 @@ def create_model(
235235
**kwargs
236236
)
237237

238+
if entry_point is None:
239+
model.entry_point = (
240+
self.entry_point if model._is_mms_version() else self._model_entry_point()
241+
)
242+
243+
return model
244+
238245
@classmethod
239246
def _prepare_init_params_from_job_description(cls, job_details, model_channel_name=None):
240247
"""Convert the job description to init params that can be handled by the

src/sagemaker/pytorch/estimator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -175,7 +175,7 @@ def create_model(
175175
return PyTorchModel(
176176
self.model_data,
177177
role or self.role,
178-
entry_point or self.entry_point,
178+
entry_point or self._model_entry_point(),
179179
framework_version=self.framework_version,
180180
py_version=self.py_version,
181181
source_dir=(source_dir or self._model_source_dir()),

src/sagemaker/rl/estimator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -232,7 +232,7 @@ def create_model(
232232
if not entry_point and (source_dir or dependencies):
233233
raise AttributeError("Please provide an `entry_point`.")
234234

235-
entry_point = entry_point or self.entry_point
235+
entry_point = entry_point or self._model_entry_point()
236236
source_dir = source_dir or self._model_source_dir()
237237
dependencies = dependencies or self.dependencies
238238

src/sagemaker/sklearn/estimator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -196,7 +196,7 @@ def create_model(
196196
return SKLearnModel(
197197
self.model_data,
198198
role,
199-
entry_point or self.entry_point,
199+
entry_point or self._model_entry_point(),
200200
source_dir=(source_dir or self._model_source_dir()),
201201
enable_cloudwatch_metrics=self.enable_cloudwatch_metrics,
202202
container_log_level=self.container_log_level,

src/sagemaker/xgboost/estimator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -172,7 +172,7 @@ def create_model(
172172
return XGBoostModel(
173173
self.model_data,
174174
role,
175-
entry_point or self.entry_point,
175+
entry_point or self._model_entry_point(),
176176
framework_version=self.framework_version,
177177
source_dir=(source_dir or self._model_source_dir()),
178178
enable_cloudwatch_metrics=self.enable_cloudwatch_metrics,
2.15 KB
Binary file not shown.

tests/integ/test_mxnet.py

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -32,11 +32,16 @@ def mxnet_training_job(
3232
sagemaker_session, mxnet_full_version, mxnet_full_py_version, cpu_instance_type
3333
):
3434
with timeout(minutes=TRAINING_DEFAULT_TIMEOUT_MINUTES):
35-
script_path = os.path.join(DATA_DIR, "mxnet_mnist", "mnist.py")
35+
s3_prefix = "integ-test-data/mxnet_mnist"
3636
data_path = os.path.join(DATA_DIR, "mxnet_mnist")
3737

38+
s3_source = sagemaker_session.upload_data(
39+
path=os.path.join(data_path, "sourcedir.tar.gz"), key_prefix="{}/src".format(s3_prefix)
40+
)
41+
3842
mx = MXNet(
39-
entry_point=script_path,
43+
entry_point="mxnet_mnist/mnist.py",
44+
source_dir=s3_source,
4045
role="SageMakerRole",
4146
framework_version=mxnet_full_version,
4247
py_version=mxnet_full_py_version,
@@ -46,10 +51,10 @@ def mxnet_training_job(
4651
)
4752

4853
train_input = mx.sagemaker_session.upload_data(
49-
path=os.path.join(data_path, "train"), key_prefix="integ-test-data/mxnet_mnist/train"
54+
path=os.path.join(data_path, "train"), key_prefix="{}/train".format(s3_prefix)
5055
)
5156
test_input = mx.sagemaker_session.upload_data(
52-
path=os.path.join(data_path, "test"), key_prefix="integ-test-data/mxnet_mnist/test"
57+
path=os.path.join(data_path, "test"), key_prefix="{}/test".format(s3_prefix)
5358
)
5459

5560
mx.fit({"train": train_input, "test": test_input})
@@ -62,7 +67,13 @@ def test_attach_deploy(mxnet_training_job, sagemaker_session, cpu_instance_type)
6267

6368
with timeout_and_delete_endpoint_by_name(endpoint_name, sagemaker_session):
6469
estimator = MXNet.attach(mxnet_training_job, sagemaker_session=sagemaker_session)
65-
predictor = estimator.deploy(1, cpu_instance_type, endpoint_name=endpoint_name)
70+
predictor = estimator.deploy(
71+
1,
72+
cpu_instance_type,
73+
entry_point="mnist.py",
74+
source_dir=os.path.join(DATA_DIR, "mxnet_mnist"),
75+
endpoint_name=endpoint_name,
76+
)
6677
data = numpy.zeros(shape=(1, 1, 28, 28))
6778
result = predictor.predict(data)
6879
assert result is not None

0 commit comments

Comments
 (0)